Update to new version of framework

xuyige 2018-09-30 21:24:05 +08:00
parent 0b86d7cf2b
commit 91f3d97ace
18 changed files with 208 additions and 62 deletions

View File

@@ -69,6 +69,6 @@ class Batch(object):
             else:
                 batch[name] = torch.stack(tensor_list, dim=0)
-        self.curidx += endidx
+        self.curidx = endidx
         return batch_x, batch_y
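The one-line fix above matters: endidx is an absolute position into the dataset, so the cursor must be assigned to it, not incremented by it. A minimal sketch of the cursor logic (the loop itself is illustrative; variable names follow the diff):

    # Illustrative only: shows why `curidx += endidx` over-advances the cursor.
    dataset_len, batch_size = 10, 4
    curidx = 0
    while curidx < dataset_len:
        endidx = min(curidx + batch_size, dataset_len)
        print("batch covers indices", list(range(curidx, endidx)))
        curidx = endidx      # fixed: jump to the end of the current batch
        # curidx += endidx   # old bug: 0 -> 4 -> 12, skipping examples 8 and 9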

View File

@@ -144,6 +144,15 @@ class DataSet(list):
         else:
             self.convert(raw_data)

+    def load_raw(self, raw_data, vocabs):
+        """Load raw inputs for inference, without labels.
+
+        :param raw_data: 2-D list of strings, the raw input examples
+        :param vocabs: dict, mapping a vocabulary name such as "word_vocab" to the vocabulary used for indexing
+        :return:
+        """
+        self.convert_for_infer(raw_data, vocabs)
+
     def split(self, ratio, shuffle=True):
         """Train/dev splitting

View File

@@ -38,14 +38,19 @@ class SeqLabelEvaluator(Evaluator):
     def __call__(self, predict, truth):
         """
-        :param predict: list of tensors, the network outputs from all batches.
+        :param predict: list of List, the network outputs from all batches.
         :param truth: list of dict, the ground truths from all batch_y.
         :return accuracy:
         """
         truth = [item["truth"] for item in truth]
-        truth = torch.cat(truth).view(-1, )
-        results = torch.Tensor(predict).view(-1, )
-        accuracy = torch.sum(results.to(truth) == truth).to(torch.float) / results.shape[0]
+        total_correct, total_count = 0., 0.
+        for x, y in zip(predict, truth):
+            mask = torch.Tensor(x).ge(1)
+            correct = torch.sum(torch.Tensor(x) * mask.float() == (y * mask.long()).float())
+            correct -= torch.sum(torch.Tensor(x).le(0))
+            total_correct += float(correct)
+            total_count += float(torch.sum(mask))
+        accuracy = total_correct / total_count
         return {"accuracy": float(accuracy)}

View File

@@ -34,7 +34,7 @@ class Predictor(object):
         """Perform inference using the trained model.

         :param network: a PyTorch model (cpu)
-        :param data: list of list of strings, [num_examples, seq_len]
+        :param data: a DataSet object.
         :return: list of list of strings, [num_examples, tag_seq_length]
         """
         # transform strings into DataSet object

View File

@ -18,6 +18,9 @@ def save_pickle(obj, pickle_path, file_name):
:param pickle_path: str, the directory where the pickle file is to be saved :param pickle_path: str, the directory where the pickle file is to be saved
:param file_name: str, the name of the pickle file. In general, it should be ended by "pkl". :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
""" """
if not os.path.exists(pickle_path):
os.mkdir(pickle_path)
print("make dir {} before saving pickle file".format(pickle_path))
with open(os.path.join(pickle_path, file_name), "wb") as f: with open(os.path.join(pickle_path, file_name), "wb") as f:
_pickle.dump(obj, f) _pickle.dump(obj, f)
print("{} saved in {}".format(file_name, pickle_path)) print("{} saved in {}".format(file_name, pickle_path))

View File

@@ -4,6 +4,8 @@ from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
 from fastNLP.core.preprocess import load_pickle
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
+from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet

 """
 mapping from model name to [URL, file_name.class_name, model_pickle_name]
@@ -76,6 +78,8 @@ class FastNLP(object):
         self.model_dir = model_dir
         self.model = None
         self.infer_type = None  # "seq_label"/"text_class"
+        self.word_vocab = None
+        self.label_vocab = None

     def load(self, model_name, config_file="config", section_name="model"):
         """
@@ -100,10 +104,10 @@ class FastNLP(object):
         print("Restore model hyper-parameters {}".format(str(model_args.data)))

         # fetch dictionary size and number of labels from pickle files
-        word_vocab = load_pickle(self.model_dir, "word2id.pkl")
-        model_args["vocab_size"] = len(word_vocab)
-        label_vocab = load_pickle(self.model_dir, "class2id.pkl")
-        model_args["num_classes"] = len(label_vocab)
+        self.word_vocab = load_pickle(self.model_dir, "word2id.pkl")
+        model_args["vocab_size"] = len(self.word_vocab)
+        self.label_vocab = load_pickle(self.model_dir, "label2id.pkl")
+        model_args["num_classes"] = len(self.label_vocab)

         # Construct the model
         model = model_class(model_args)
@@ -130,8 +134,11 @@ class FastNLP(object):
         # tokenize: list of string ---> 2-D list of string
         infer_input = self.tokenize(raw_input, language="zh")

-        # 2-D list of string ---> 2-D list of tags
-        results = infer.predict(self.model, infer_input)
+        # create DataSet: 2-D list of strings ----> DataSet
+        infer_data = self._create_data_set(infer_input)
+
+        # DataSet ---> 2-D list of tags
+        results = infer.predict(self.model, infer_data)

         # 2-D list of tags ---> list of final answers
         outputs = self._make_output(results, infer_input)
@@ -154,6 +161,11 @@ class FastNLP(object):
         return module

     def _create_inference(self, model_dir):
+        """Specify which inference class to use, given the current task.
+
+        :param model_dir: str, the directory where the trained model is saved
+        :return: an inference instance matching self.infer_type
+        """
         if self.infer_type == "seq_label":
             return SeqLabelInfer(model_dir)
         elif self.infer_type == "text_class":
@@ -161,6 +173,24 @@ class FastNLP(object):
         else:
             raise ValueError("fail to create inference instance")

+    def _create_data_set(self, infer_input):
+        """Create a DataSet object given the raw inputs.
+
+        :param infer_input: 2-D lists of strings
+        :return data_set: a DataSet object
+        """
+        if self.infer_type == "seq_label":
+            data_set = SeqLabelDataSet()
+            data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
+            return data_set
+        elif self.infer_type == "text_class":
+            data_set = TextClassifyDataSet()
+            data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
+            return data_set
+        else:
+            raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))
+
     def _load(self, model_dir, model_name):
         # To do
         return 0
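Taken together, inference in fastnlp.py now flows raw text -> tokenize -> _create_data_set -> infer.predict. A hedged end-to-end sketch; the model name, config names, and the top-level run() call are assumptions not confirmed by this diff:

    from fastNLP.fastnlp import FastNLP

    nlp = FastNLP(model_dir="./save/")   # directory holding word2id.pkl / label2id.pkl
    nlp.load("cws_basic_model", config_file="config", section_name="model")  # hypothetical model name
    print(nlp.run(["这是一个测试句子。"]))  # run() assumed from the surrounding class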

View File

@@ -18,7 +18,7 @@ class ConfigSaver(object):
         :return: The section.
         """
         sect = ConfigSection()
-        ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect})
+        ConfigLoader().load_config(self.file_path, {sect_name: sect})
         return sect

     def _read_section(self):
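ConfigLoader is now constructed without arguments, and the path is passed to load_config alone; the same pattern is applied to the loader and saver tests below. For reference:

    from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

    sect = ConfigSection()
    ConfigLoader().load_config("./test/data_for_tests/config", {"test": sect})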

View File

@@ -43,8 +43,10 @@ class TestCase1(unittest.TestCase):

         # use batch to iterate dataset
         data_iterator = Batch(data, 2, SeqSampler(), False)
+        total_data = 0
         for batch_x, batch_y in data_iterator:
-            self.assertEqual(len(batch_x), 2)
+            total_data += batch_x["text"].size(0)
+            self.assertTrue(batch_x["text"].size(0) == 2 or total_data == len(raw_texts))
             self.assertTrue(isinstance(batch_x, dict))
             self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
             self.assertTrue(isinstance(batch_y, dict))
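The relaxed assertion allows the final batch to be smaller than batch_size; a quick sanity check of the expected sizes (dataset length illustrative):

    # e.g. 5 examples with batch_size 2 yield batches of sizes 2, 2, 1
    n, bs = 5, 2
    sizes = [min(bs, n - i) for i in range(0, n, bs)]
    print(sizes)   # [2, 2, 1] -- only the last batch may be short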

View File

@@ -1,20 +1,42 @@
-import sys, os
+import os
+import sys

 sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path

 from fastNLP.core import metrics
 # from sklearn import metrics as skmetrics
 import unittest
+import numpy as np
 from numpy import random
+from fastNLP.core.metrics import SeqLabelEvaluator
+import torch


 def generate_fake_label(low, high, size):
     return random.randint(low, high, size), random.randint(low, high, size)


+class TestEvaluator(unittest.TestCase):
+    def test_a(self):
+        evaluator = SeqLabelEvaluator()
+        pred = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]
+        truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4])}]
+        ans = evaluator(pred, truth)
+        print(ans)
+
+    def test_b(self):
+        evaluator = SeqLabelEvaluator()
+        pred = [[1, 2, 3, 4, 5, 0, 0], [1, 2, 3, 4, 5, 0, 0]]
+        truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3, 0, 0])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4, 0, 0])}]
+        ans = evaluator(pred, truth)
+        print(ans)
+
+
 class TestMetrics(unittest.TestCase):
     delta = 1e-5
     # test for binary, multiclass, multilabel
     data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)]
     fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types]

     def test_accuracy_score(self):
         for y_true, y_pred in self.fake_data:
             for normalize in [True, False]:
@@ -22,7 +44,7 @@ class TestMetrics(unittest.TestCase):
                 test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
                 # ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
                 # self.assertAlmostEqual(test, ans, delta=self.delta)

     def test_recall_score(self):
         for y_true, y_pred in self.fake_data:
             # print(y_true.shape)
@@ -73,5 +95,6 @@ class TestMetrics(unittest.TestCase):
         # ans = skmetrics.f1_score(y_true, y_pred)
         # self.assertAlmostEqual(ans, test, delta=self.delta)

+
 if __name__ == '__main__':
     unittest.main()
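For the record, both new cases should print {'accuracy': 0.6}: in test_a all ten positions are real and six match, and in test_b the four padding zeros are masked out, again leaving six matches over ten real tokens.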

View File

@@ -2,9 +2,12 @@ import os
 import unittest

 from fastNLP.core.predictor import Predictor
+from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet
 from fastNLP.core.preprocess import save_pickle
-from fastNLP.models.sequence_modeling import SeqLabeling
 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.loader.base_loader import BaseLoader
+from fastNLP.models.sequence_modeling import SeqLabeling
+from fastNLP.models.cnn_text_classification import CNNText


 class TestPredictor(unittest.TestCase):
@@ -28,23 +31,44 @@ class TestPredictor(unittest.TestCase):
         vocab = Vocabulary()
         vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
         class_vocab = Vocabulary()
-        class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4}
+        class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}

         os.system("mkdir save")
-        save_pickle(class_vocab, "./save/", "class2id.pkl")
+        save_pickle(class_vocab, "./save/", "label2id.pkl")
         save_pickle(vocab, "./save/", "word2id.pkl")

-        model = SeqLabeling(model_args)
-        predictor = Predictor("./save/", task="seq_label")
+        model = CNNText(model_args)
+        import fastNLP.core.predictor as pre
+        predictor = Predictor("./save/", pre.text_classify_post_processor)

-        results = predictor.predict(network=model, data=infer_data)
+        # Load infer data
+        infer_data_set = TextClassifyDataSet(loader=BaseLoader())
+        infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
+
+        results = predictor.predict(network=model, data=infer_data_set)

         self.assertTrue(isinstance(results, list))
         self.assertGreater(len(results), 0)
+        self.assertEqual(len(results), len(infer_data))
         for res in results:
-            self.assertTrue(isinstance(res, list))
-            self.assertEqual(len(res), 5)
+            self.assertTrue(isinstance(res, str))
+            self.assertTrue(res in class_vocab.word2idx)
+
+        del model, predictor, infer_data_set
+
+        model = SeqLabeling(model_args)
+        predictor = Predictor("./save/", pre.seq_label_post_processor)
+
+        infer_data_set = SeqLabelDataSet(loader=BaseLoader())
+        infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
+
+        results = predictor.predict(network=model, data=infer_data_set)
+        self.assertTrue(isinstance(results, list))
+        self.assertEqual(len(results), len(infer_data))
+        for i in range(len(infer_data)):
+            res = results[i]
+            self.assertTrue(isinstance(res, list))
+            self.assertEqual(len(res), len(infer_data[i]))
+            self.assertTrue(isinstance(res[0], str))

         os.system("rm -rf save")
         print("pickle path deleted")

View File

@@ -1,8 +1,9 @@
 import os
 import unittest

-from fastNLP.core.dataset import DataSet
-from fastNLP.core.field import TextField
+from fastNLP.core.dataset import SeqLabelDataSet
+from fastNLP.core.metrics import SeqLabelEvaluator
+from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.models.sequence_modeling import SeqLabeling
@@ -21,7 +22,7 @@ class TestTester(unittest.TestCase):
         }
         valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                       "save_loss": True, "batch_size": 2, "pickle_path": "./save/",
-                      "use_cuda": False, "print_every_step": 1}
+                      "use_cuda": False, "print_every_step": 1, "evaluator": SeqLabelEvaluator()}

         train_data = [
             [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
@@ -34,16 +35,17 @@ class TestTester(unittest.TestCase):
         vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
         label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

-        data_set = DataSet()
+        data_set = SeqLabelDataSet()
         for example in train_data:
             text, label = example[0], example[1]
             x = TextField(text, False)
+            x_len = LabelField(len(text), is_target=False)
             y = TextField(label, is_target=True)
-            ins = Instance(word_seq=x, label_seq=y)
+            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
             data_set.append(ins)
         data_set.index_field("word_seq", vocab)
-        data_set.index_field("label_seq", label_vocab)
+        data_set.index_field("truth", label_vocab)

         model = SeqLabeling(model_args)

View File

@@ -1,8 +1,9 @@
 import os
 import unittest

-from fastNLP.core.dataset import DataSet
-from fastNLP.core.field import TextField
+from fastNLP.core.dataset import SeqLabelDataSet
+from fastNLP.core.metrics import SeqLabelEvaluator
+from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
 from fastNLP.core.loss import Loss
 from fastNLP.core.optimizer import Optimizer
@@ -12,14 +13,15 @@ from fastNLP.models.sequence_modeling import SeqLabeling

 class TestTrainer(unittest.TestCase):
     def test_case_1(self):
-        args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/",
+        args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
                 "save_best_dev": True, "model_name": "default_model_name.pkl",
-                "loss": Loss(None),
+                "loss": Loss("cross_entropy"),
                 "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
                 "vocab_size": 10,
                 "word_emb_dim": 100,
                 "rnn_hidden_units": 100,
-                "num_classes": 5
+                "num_classes": 5,
+                "evaluator": SeqLabelEvaluator()
                 }
         trainer = SeqLabelTrainer(**args)
@@ -34,16 +36,17 @@ class TestTrainer(unittest.TestCase):
         vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
         label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

-        data_set = DataSet()
+        data_set = SeqLabelDataSet()
         for example in train_data:
             text, label = example[0], example[1]
             x = TextField(text, False)
-            y = TextField(label, is_target=True)
-            ins = Instance(word_seq=x, label_seq=y)
+            x_len = LabelField(len(text), is_target=False)
+            y = TextField(label, is_target=False)
+            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
             data_set.append(ins)
         data_set.index_field("word_seq", vocab)
-        data_set.index_field("label_seq", label_vocab)
+        data_set.index_field("truth", label_vocab)

         model = SeqLabeling(args)
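Both the tester and trainer tests now build instances with the same three-field convention the new framework expects: word_seq (indexed by the word vocab), word_seq_origin_len (the unpadded length, carried as a LabelField), and truth (indexed by the label vocab), replacing the old label_seq field.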

View File

@@ -9,10 +9,54 @@ input = [1,2,3]
 text = "this is text"
-doubles = 0.5
+doubles = 0.8
+tt = 0.5
+test = 105
+str = "this is a str"
+double = 0.5

 [t]
 x = "this is an test section"

 [test-case-2]
 double = 0.5
+doubles = 0.8
+tt = 0.5
+test = 105
+str = "this is a str"
+
+[another-test]
+doubles = 0.8
+tt = 0.5
+test = 105
+str = "this is a str"
+double = 0.5
+
+[one-another-test]
+doubles = 0.8
+tt = 0.5
+test = 105
+str = "this is a str"
+double = 0.5

View File

@@ -31,7 +31,7 @@ class TestConfigLoader(unittest.TestCase):
             return dict

         test_arg = ConfigSection()
-        ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg})
+        ConfigLoader().load_config(os.path.join("./test/loader", "config"), {"test": test_arg})

         section = read_section_from_config(os.path.join("./test/loader", "config"), "test")

View File

@@ -1,3 +1,4 @@
+import os
 import unittest

 from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
@@ -14,28 +15,28 @@ class TestDatasetLoader(unittest.TestCase):

     def test_case_TokenizeDatasetLoader(self):
         loader = TokenizeDataSetLoader()
-        data = loader.load("test/data_for_tests/", max_seq_len=32)
+        data = loader.load("./test/data_for_tests/cws_pku_utf_8", max_seq_len=32)
         print("pass TokenizeDataSetLoader test!")

     def test_case_POSDatasetLoader(self):
         loader = POSDataSetLoader()
-        data = loader.load()
-        datas = loader.load_lines()
+        data = loader.load("./test/data_for_tests/people.txt")
+        datas = loader.load_lines("./test/data_for_tests/people.txt")
         print("pass POSDataSetLoader test!")

     def test_case_LMDatasetLoader(self):
         loader = LMDataSetLoader()
-        data = loader.load()
-        datas = loader.load_lines()
+        data = loader.load("./test/data_for_tests/charlm.txt")
+        datas = loader.load_lines("./test/data_for_tests/charlm.txt")
         print("pass TokenizeDataSetLoader test!")

     def test_PeopleDailyCorpusLoader(self):
         loader = PeopleDailyCorpusLoader()
-        _, _ = loader.load()
+        _, _ = loader.load("./test/data_for_tests/people_daily_raw.txt")

     def test_ConllLoader(self):
-        loader = ConllLoader("./test/data_for_tests/conll_example.txt")
-        _ = loader.load()
+        loader = ConllLoader()
+        _ = loader.load("./test/data_for_tests/conll_example.txt")

 if __name__ == '__main__':

View File

@@ -13,10 +13,10 @@ from fastNLP.models.sequence_modeling import SeqLabeling
 from fastNLP.saver.model_saver import ModelSaver

 data_name = "pku_training.utf8"
-cws_data_path = "test/data_for_tests/cws_pku_utf_8"
+cws_data_path = "./test/data_for_tests/cws_pku_utf_8"
 pickle_path = "./save/"
-data_infer_path = "test/data_for_tests/people_infer.txt"
-config_path = "test/data_for_tests/config"
+data_infer_path = "./test/data_for_tests/people_infer.txt"
+config_path = "./test/data_for_tests/config"

 def infer():
     # Load infer configuration, the same as test

View File

@@ -21,7 +21,7 @@ class TestConfigSaver(unittest.TestCase):
         standard_section = ConfigSection()
         t_section = ConfigSection()
-        ConfigLoader(config_file_path).load_config(config_file_path, {"test": standard_section, "t": t_section})
+        ConfigLoader().load_config(config_file_path, {"test": standard_section, "t": t_section})

         config_saver = ConfigSaver(config_file_path)
@@ -48,11 +48,11 @@ class TestConfigSaver(unittest.TestCase):
         one_another_test_section = ConfigSection()
         a_test_case_2_section = ConfigSection()
-        ConfigLoader(config_file_path).load_config(config_file_path, {"test": test_section,
-                                                                      "another-test": another_test_section,
-                                                                      "t": at_section,
-                                                                      "one-another-test": one_another_test_section,
-                                                                      "test-case-2": a_test_case_2_section})
+        ConfigLoader().load_config(config_file_path, {"test": test_section,
+                                                      "another-test": another_test_section,
+                                                      "t": at_section,
+                                                      "one-another-test": one_another_test_section,
+                                                      "test-case-2": a_test_case_2_section})

         assert test_section == standard_section
         assert at_section == t_section

View File

@@ -54,7 +54,7 @@ def mock_cws():
     class2id = Vocabulary(need_default=False)
     label_list = ['B', 'M', 'E', 'S']
     class2id.update(label_list)
-    save_pickle(class2id, "./mock/", "class2id.pkl")
+    save_pickle(class2id, "./mock/", "label2id.pkl")

     model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)}
     config_file = """
@@ -115,7 +115,7 @@ def mock_pos_tag():
     idx2label = Vocabulary(need_default=False)
     label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv']
     idx2label.update(label_list)
-    save_pickle(idx2label, "./mock/", "class2id.pkl")
+    save_pickle(idx2label, "./mock/", "label2id.pkl")

     model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
     config_file = """
@@ -163,7 +163,7 @@ def mock_text_classify():
     idx2label = Vocabulary(need_default=False)
     label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F']
     idx2label.update(label_list)
-    save_pickle(idx2label, "./mock/", "class2id.pkl")
+    save_pickle(idx2label, "./mock/", "label2id.pkl")

     model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
     config_file = """