Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-03 20:57:37 +08:00)

Commit 91f3d97ace (parent 0b86d7cf2b): Update to new version of framework
@@ -69,6 +69,6 @@ class Batch(object):
             else:
                 batch[name] = torch.stack(tensor_list, dim=0)

-        self.curidx += endidx
+        self.curidx = endidx
         return batch_x, batch_y

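The one-line fix above corrects how the batch cursor advances: endidx is already an absolute position in the dataset, so `self.curidx += endidx` skipped examples after the first batch. A minimal sketch of the iteration logic, assuming endidx is computed per step as in the surrounding (unshown) code:

    # Minimal sketch of the Batch cursor; endidx is assumed to be the
    # absolute end of the current slice (the hunk only shows the update).
    class BatchSketch:
        def __init__(self, dataset, batch_size):
            self.dataset, self.batch_size, self.curidx = dataset, batch_size, 0

        def __iter__(self):
            while self.curidx < len(self.dataset):
                endidx = min(self.curidx + self.batch_size, len(self.dataset))
                batch = self.dataset[self.curidx:endidx]
                self.curidx = endidx  # the old "+=" jumped past endidx examples
                yield batch

    print(list(BatchSketch(list(range(5)), 2)))  # -> [[0, 1], [2, 3], [4]]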
@@ -144,6 +144,15 @@ class DataSet(list):
         else:
             self.convert(raw_data)

+    def load_raw(self, raw_data, vocabs):
+        """
+
+        :param raw_data:
+        :param vocabs:
+        :return:
+        """
+        self.convert_for_infer(raw_data, vocabs)
+
     def split(self, ratio, shuffle=True):
         """Train/dev splitting

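load_raw is a thin inference-side entry point: it forwards to convert_for_infer, which indexes raw text without requiring labels. A hedged usage sketch; word_vocab is assumed to be an already-built Vocabulary, and the vocab key mirrors the call sites added later in this commit:

    data_set = SeqLabelDataSet()
    data_set.load_raw([["this", "is", "raw"]], {"word_vocab": word_vocab})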
@@ -38,14 +38,19 @@ class SeqLabelEvaluator(Evaluator):
     def __call__(self, predict, truth):
         """

-        :param predict: list of tensors, the network outputs from all batches.
+        :param predict: list of List, the network outputs from all batches.
         :param truth: list of dict, the ground truths from all batch_y.
         :return accuracy:
         """
         truth = [item["truth"] for item in truth]
-        truth = torch.cat(truth).view(-1, )
-        results = torch.Tensor(predict).view(-1, )
-        accuracy = torch.sum(results.to(truth) == truth).to(torch.float) / results.shape[0]
+        total_correct, total_count = 0., 0.
+        for x, y in zip(predict, truth):
+            mask = torch.Tensor(x).ge(1)
+            correct = torch.sum(torch.Tensor(x) * mask.float() == (y * mask.long()).float())
+            correct -= torch.sum(torch.Tensor(x).le(0))
+            total_correct += float(correct)
+            total_count += float(torch.sum(mask))
+        accuracy = total_correct / total_count
         return {"accuracy": float(accuracy)}

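The evaluator now scores padded per-sequence predictions instead of one flat tensor. The .ge(1) mask treats index 0 as padding: positions where both sides are masked to zero would otherwise count as matches, so the torch.Tensor(x).le(0) term subtracts them back out. A small walk-through on one padded sequence (values chosen for illustration):

    import torch

    x = [1, 2, 3, 0, 0]                    # predicted tag ids, 0 = padding
    y = torch.LongTensor([1, 2, 4, 0, 0])  # gold tag ids
    mask = torch.Tensor(x).ge(1)           # [1, 1, 1, 0, 0]
    correct = torch.sum(torch.Tensor(x) * mask.float() == (y * mask.long()).float())
    correct -= torch.sum(torch.Tensor(x).le(0))     # drop the two padded "matches"
    print(float(correct) / float(torch.sum(mask)))  # 2 correct of 3 real tokens

Note that the mask is derived from the predictions, so a real token predicted as 0 is excluded from the count.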
@@ -34,7 +34,7 @@ class Predictor(object):
         """Perform inference using the trained model.

         :param network: a PyTorch model (cpu)
-        :param data: list of list of strings, [num_examples, seq_len]
+        :param data: a DataSet object.
         :return: list of list of strings, [num_examples, tag_seq_length]
         """
         # transform strings into DataSet object

@@ -18,6 +18,9 @@ def save_pickle(obj, pickle_path, file_name):
    :param pickle_path: str, the directory where the pickle file is to be saved
    :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
    """
+    if not os.path.exists(pickle_path):
+        os.mkdir(pickle_path)
+        print("make dir {} before saving pickle file".format(pickle_path))
    with open(os.path.join(pickle_path, file_name), "wb") as f:
        _pickle.dump(obj, f)
    print("{} saved in {}".format(file_name, pickle_path))

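save_pickle now creates the target directory on demand rather than failing on a missing path. os.mkdir only creates the leaf directory; if nested pickle paths ever need to be supported, the usual alternative (not what this commit uses) would be:

    import os
    os.makedirs(pickle_path, exist_ok=True)  # also creates intermediate dirs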
@@ -4,6 +4,8 @@ from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
 from fastNLP.core.preprocess import load_pickle
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
+from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
+

 """
 mapping from model name to [URL, file_name.class_name, model_pickle_name]
@@ -76,6 +78,8 @@ class FastNLP(object):
         self.model_dir = model_dir
         self.model = None
         self.infer_type = None  # "seq_label"/"text_class"
+        self.word_vocab = None
+        self.label_vocab = None

     def load(self, model_name, config_file="config", section_name="model"):
         """
@@ -100,10 +104,10 @@ class FastNLP(object):
         print("Restore model hyper-parameters {}".format(str(model_args.data)))

         # fetch dictionary size and number of labels from pickle files
-        word_vocab = load_pickle(self.model_dir, "word2id.pkl")
-        model_args["vocab_size"] = len(word_vocab)
-        label_vocab = load_pickle(self.model_dir, "class2id.pkl")
-        model_args["num_classes"] = len(label_vocab)
+        self.word_vocab = load_pickle(self.model_dir, "word2id.pkl")
+        model_args["vocab_size"] = len(self.word_vocab)
+        self.label_vocab = load_pickle(self.model_dir, "label2id.pkl")
+        model_args["num_classes"] = len(self.label_vocab)

         # Construct the model
         model = model_class(model_args)
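The loaded vocabularies are now cached on the instance (self.word_vocab, self.label_vocab) so that _create_data_set, added below, can reuse them at inference time. The label pickle is also renamed from class2id.pkl to label2id.pkl; the same rename is applied to the test fixtures and mocks later in this commit.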
@@ -130,8 +134,11 @@ class FastNLP(object):
         # tokenize: list of string ---> 2-D list of string
         infer_input = self.tokenize(raw_input, language="zh")

-        # 2-D list of string ---> 2-D list of tags
-        results = infer.predict(self.model, infer_input)
+        # create DataSet: 2-D list of strings ----> DataSet
+        infer_data = self._create_data_set(infer_input)
+
+        # DataSet ---> 2-D list of tags
+        results = infer.predict(self.model, infer_data)

         # 2-D list of tags ---> list of final answers
         outputs = self._make_output(results, infer_input)
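Inference no longer feeds the tokenized 2-D lists straight into the inferencer; it first wraps them in a DataSet. The full flow after this change, sketched with the names used in the hunk (infer and raw_input come from the enclosing method):

    infer_input = self.tokenize(raw_input, language="zh")  # list of str -> 2-D list of str
    infer_data = self._create_data_set(infer_input)        # 2-D list -> DataSet
    results = infer.predict(self.model, infer_data)        # DataSet -> 2-D list of tags
    outputs = self._make_output(results, infer_input)      # tags -> final answers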
@@ -154,6 +161,11 @@ class FastNLP(object):
         return module

     def _create_inference(self, model_dir):
+        """Specify which task to perform.
+
+        :param model_dir:
+        :return:
+        """
         if self.infer_type == "seq_label":
             return SeqLabelInfer(model_dir)
         elif self.infer_type == "text_class":
@@ -161,6 +173,24 @@ class FastNLP(object):
         else:
             raise ValueError("fail to create inference instance")

+    def _create_data_set(self, infer_input):
+        """Create a DataSet object given the raw inputs.
+
+        :param infer_input: 2-D lists of strings
+        :return data_set: a DataSet object
+        """
+        if self.infer_type == "seq_label":
+            data_set = SeqLabelDataSet()
+            data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
+            return data_set
+        elif self.infer_type == "text_class":
+            data_set = TextClassifyDataSet()
+            data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
+            return data_set
+        else:
+            raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))
+
+
     def _load(self, model_dir, model_name):
         # To do
         return 0

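_create_data_set dispatches on self.infer_type and builds the matching DataSet flavor, feeding it the word vocabulary cached by load(). A hypothetical call, assuming a FastNLP instance nlp that has been loaded with a seq_label model:

    data_set = nlp._create_data_set([["raw", "tokens", "here"]])
    # -> a SeqLabelDataSet indexed with nlp.word_vocab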
@@ -18,7 +18,7 @@ class ConfigSaver(object):
         :return: The section.
         """
         sect = ConfigSection()
-        ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect})
+        ConfigLoader().load_config(self.file_path, {sect_name: sect})
         return sect

     def _read_section(self):

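ConfigLoader's constructor no longer takes the config path; the path is supplied only to load_config. The same signature change is applied at every call site in the tests below. A sketch of the new pattern (the path and section name here are illustrative):

    from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

    sect = ConfigSection()
    ConfigLoader().load_config("config", {"model": sect})  # path goes to load_config only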
@@ -43,8 +43,10 @@ class TestCase1(unittest.TestCase):

         # use batch to iterate dataset
         data_iterator = Batch(data, 2, SeqSampler(), False)
+        total_data = 0
         for batch_x, batch_y in data_iterator:
-            self.assertEqual(len(batch_x), 2)
+            total_data += batch_x["text"].size(0)
+            self.assertTrue(batch_x["text"].size(0) == 2 or total_data == len(raw_texts))
             self.assertTrue(isinstance(batch_x, dict))
             self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
             self.assertTrue(isinstance(batch_y, dict))

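The old assertion required every batch to hold exactly two examples, which fails whenever the dataset size is not a multiple of the batch size. The test now tracks a running total and accepts a smaller final batch: with five examples and batch_size 2, for instance, the iterator yields batches of sizes 2, 2 and 1.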
@@ -1,20 +1,42 @@
-import sys, os
+import os
+import sys

 sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path

 from fastNLP.core import metrics
 # from sklearn import metrics as skmetrics
 import unittest
-import numpy as np
 from numpy import random
+from fastNLP.core.metrics import SeqLabelEvaluator
+import torch
+
+
 def generate_fake_label(low, high, size):
     return random.randint(low, high, size), random.randint(low, high, size)
+
+
+class TestEvaluator(unittest.TestCase):
+    def test_a(self):
+        evaluator = SeqLabelEvaluator()
+        pred = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]
+        truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4])}]
+        ans = evaluator(pred, truth)
+        print(ans)
+
+    def test_b(self):
+        evaluator = SeqLabelEvaluator()
+        pred = [[1, 2, 3, 4, 5, 0, 0], [1, 2, 3, 4, 5, 0, 0]]
+        truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3, 0, 0])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4, 0, 0])}]
+        ans = evaluator(pred, truth)
+        print(ans)
+
+
 class TestMetrics(unittest.TestCase):
     delta = 1e-5
     # test for binary, multiclass, multilabel
     data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)]
     fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types]

     def test_accuracy_score(self):
         for y_true, y_pred in self.fake_data:
             for normalize in [True, False]:
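The two new TestEvaluator cases feed the rewritten SeqLabelEvaluator directly. test_b repeats test_a's sequences with trailing zeros appended; since the mask excludes padding, both cases should come out to the same accuracy (3 of the 5 real positions match in each sequence, i.e. 0.6), which is exactly the padding-invariance the rewrite is meant to guarantee.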
@@ -22,7 +44,7 @@ class TestMetrics(unittest.TestCase):
                 test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
                 # ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
                 # self.assertAlmostEqual(test, ans, delta=self.delta)

     def test_recall_score(self):
         for y_true, y_pred in self.fake_data:
             # print(y_true.shape)
@@ -73,5 +95,6 @@ class TestMetrics(unittest.TestCase):
         # ans = skmetrics.f1_score(y_true, y_pred)
         # self.assertAlmostEqual(ans, test, delta=self.delta)

+
 if __name__ == '__main__':
     unittest.main()

@@ -2,9 +2,12 @@ import os
 import unittest

 from fastNLP.core.predictor import Predictor
+from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet
 from fastNLP.core.preprocess import save_pickle
-from fastNLP.models.sequence_modeling import SeqLabeling
 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.loader.base_loader import BaseLoader
+from fastNLP.models.sequence_modeling import SeqLabeling
+from fastNLP.models.cnn_text_classification import CNNText


 class TestPredictor(unittest.TestCase):
@@ -28,23 +31,44 @@ class TestPredictor(unittest.TestCase):
         vocab = Vocabulary()
         vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
         class_vocab = Vocabulary()
-        class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4}
+        class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}

         os.system("mkdir save")
-        save_pickle(class_vocab, "./save/", "class2id.pkl")
+        save_pickle(class_vocab, "./save/", "label2id.pkl")
         save_pickle(vocab, "./save/", "word2id.pkl")

-        model = SeqLabeling(model_args)
-        predictor = Predictor("./save/", task="seq_label")
+        model = CNNText(model_args)
+        import fastNLP.core.predictor as pre
+        predictor = Predictor("./save/", pre.text_classify_post_processor)

-        results = predictor.predict(network=model, data=infer_data)
+        # Load infer data
+        infer_data_set = TextClassifyDataSet(loader=BaseLoader())
+        infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
+
+        results = predictor.predict(network=model, data=infer_data_set)

         self.assertTrue(isinstance(results, list))
         self.assertGreater(len(results), 0)
+        self.assertEqual(len(results), len(infer_data))
         for res in results:
+            self.assertTrue(isinstance(res, str))
+            self.assertTrue(res in class_vocab.word2idx)
+
+        del model, predictor, infer_data_set
+
+        model = SeqLabeling(model_args)
+        predictor = Predictor("./save/", pre.seq_label_post_processor)
+
+        infer_data_set = SeqLabelDataSet(loader=BaseLoader())
+        infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
+
+        results = predictor.predict(network=model, data=infer_data_set)
+        self.assertTrue(isinstance(results, list))
+        self.assertEqual(len(results), len(infer_data))
+        for i in range(len(infer_data)):
+            res = results[i]
             self.assertTrue(isinstance(res, list))
-            self.assertEqual(len(res), 5)
-            self.assertTrue(isinstance(res[0], str))
+            self.assertEqual(len(res), len(infer_data[i]))

         os.system("rm -rf save")
         print("pickle path deleted")

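The predictor test now covers both task types end to end: a CNNText model with text_classify_post_processor, where each result is a single label string drawn from class_vocab, and a SeqLabeling model with seq_label_post_processor, where each result is a tag list whose length must equal the input sentence length (len(infer_data[i])) rather than the previously hard-coded 5. In both passes the raw inputs are first wrapped via convert_for_infer with the word vocabulary, matching the new DataSet-based Predictor API documented above.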
@@ -1,8 +1,9 @@
 import os
 import unittest

-from fastNLP.core.dataset import DataSet
-from fastNLP.core.field import TextField
+from fastNLP.core.dataset import SeqLabelDataSet
+from fastNLP.core.metrics import SeqLabelEvaluator
+from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.models.sequence_modeling import SeqLabeling

|
|||||||
}
|
}
|
||||||
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
|
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
|
||||||
"save_loss": True, "batch_size": 2, "pickle_path": "./save/",
|
"save_loss": True, "batch_size": 2, "pickle_path": "./save/",
|
||||||
"use_cuda": False, "print_every_step": 1}
|
"use_cuda": False, "print_every_step": 1, "evaluator": SeqLabelEvaluator()}
|
||||||
|
|
||||||
train_data = [
|
train_data = [
|
||||||
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
|
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
|
||||||
@@ -34,16 +35,17 @@ class TestTester(unittest.TestCase):
         vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
         label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

-        data_set = DataSet()
+        data_set = SeqLabelDataSet()
         for example in train_data:
             text, label = example[0], example[1]
             x = TextField(text, False)
+            x_len = LabelField(len(text), is_target=False)
             y = TextField(label, is_target=True)
-            ins = Instance(word_seq=x, label_seq=y)
+            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
             data_set.append(ins)

         data_set.index_field("word_seq", vocab)
-        data_set.index_field("label_seq", label_vocab)
+        data_set.index_field("truth", label_vocab)

         model = SeqLabeling(model_args)

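The instance schema changes in step with the new framework version: the target field is named truth instead of label_seq, and each instance now carries its original sequence length in word_seq_origin_len. Building one training instance under the new names, assuming the imports from the hunk above:

    x = TextField(['a', 'b', 'c', 'd', 'e'], False)
    x_len = LabelField(5, is_target=False)  # original sequence length
    y = TextField(['a', '@', 'c', 'd', 'e'], is_target=True)
    ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)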
@@ -1,8 +1,9 @@
 import os
 import unittest

-from fastNLP.core.dataset import DataSet
-from fastNLP.core.field import TextField
+from fastNLP.core.dataset import SeqLabelDataSet
+from fastNLP.core.metrics import SeqLabelEvaluator
+from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
 from fastNLP.core.loss import Loss
 from fastNLP.core.optimizer import Optimizer
@@ -12,14 +13,15 @@ from fastNLP.models.sequence_modeling import SeqLabeling

 class TestTrainer(unittest.TestCase):
     def test_case_1(self):
-        args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/",
+        args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
                 "save_best_dev": True, "model_name": "default_model_name.pkl",
-                "loss": Loss(None),
+                "loss": Loss("cross_entropy"),
                 "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
                 "vocab_size": 10,
                 "word_emb_dim": 100,
                 "rnn_hidden_units": 100,
-                "num_classes": 5
+                "num_classes": 5,
+                "evaluator": SeqLabelEvaluator()
                 }
         trainer = SeqLabelTrainer(**args)

@@ -34,16 +36,17 @@ class TestTrainer(unittest.TestCase):
         vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
         label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

-        data_set = DataSet()
+        data_set = SeqLabelDataSet()
         for example in train_data:
             text, label = example[0], example[1]
             x = TextField(text, False)
-            y = TextField(label, is_target=True)
-            ins = Instance(word_seq=x, label_seq=y)
+            x_len = LabelField(len(text), is_target=False)
+            y = TextField(label, is_target=False)
+            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
             data_set.append(ins)

         data_set.index_field("word_seq", vocab)
-        data_set.index_field("label_seq", label_vocab)
+        data_set.index_field("truth", label_vocab)

         model = SeqLabeling(args)

@@ -9,10 +9,54 @@ input = [1,2,3]

 text = "this is text"

-doubles = 0.5
+doubles = 0.8
+
+tt = 0.5
+
+test = 105
+
+str = "this is a str"
+
+double = 0.5
+

 [t]
 x = "this is an test section"


 [test-case-2]
 double = 0.5
+
+doubles = 0.8
+
+tt = 0.5
+
+test = 105
+
+str = "this is a str"
+
+[another-test]
+doubles = 0.8
+
+tt = 0.5
+
+test = 105
+
+str = "this is a str"
+
+double = 0.5
+
+
+[one-another-test]
+doubles = 0.8
+
+tt = 0.5
+
+test = 105
+
+str = "this is a str"
+
+double = 0.5

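The config fixture gains the tt/test/str/double keys in its default and [test-case-2] sections plus two new sections, [another-test] and [one-another-test], so that the TestConfigSaver cases further down can load and compare every section they reference.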
@@ -31,7 +31,7 @@ class TestConfigLoader(unittest.TestCase):
             return dict

         test_arg = ConfigSection()
-        ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg})
+        ConfigLoader().load_config(os.path.join("./test/loader", "config"), {"test": test_arg})

         section = read_section_from_config(os.path.join("./test/loader", "config"), "test")

@@ -1,3 +1,4 @@
+import os
 import unittest

 from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
@@ -14,28 +15,28 @@ class TestDatasetLoader(unittest.TestCase):

     def test_case_TokenizeDatasetLoader(self):
         loader = TokenizeDataSetLoader()
-        data = loader.load("test/data_for_tests/", max_seq_len=32)
+        data = loader.load("./test/data_for_tests/cws_pku_utf_8", max_seq_len=32)
         print("pass TokenizeDataSetLoader test!")

     def test_case_POSDatasetLoader(self):
         loader = POSDataSetLoader()
-        data = loader.load()
-        datas = loader.load_lines()
+        data = loader.load("./test/data_for_tests/people.txt")
+        datas = loader.load_lines("./test/data_for_tests/people.txt")
         print("pass POSDataSetLoader test!")

     def test_case_LMDatasetLoader(self):
         loader = LMDataSetLoader()
-        data = loader.load()
-        datas = loader.load_lines()
+        data = loader.load("./test/data_for_tests/charlm.txt")
+        datas = loader.load_lines("./test/data_for_tests/charlm.txt")
         print("pass TokenizeDataSetLoader test!")

     def test_PeopleDailyCorpusLoader(self):
         loader = PeopleDailyCorpusLoader()
-        _, _ = loader.load()
+        _, _ = loader.load("./test/data_for_tests/people_daily_raw.txt")

     def test_ConllLoader(self):
-        loader = ConllLoader("./test/data_for_tests/conll_example.txt")
-        _ = loader.load()
+        loader = ConllLoader()
+        _ = loader.load("./test/data_for_tests/conll_example.txt")


 if __name__ == '__main__':

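All dataset loaders move to the same convention: construct with no arguments and pass the fixture path to load() (ConllLoader previously did the opposite). The resulting pattern, using one of the repo-relative paths from the hunk:

    loader = ConllLoader()
    data = loader.load("./test/data_for_tests/conll_example.txt")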
@@ -13,10 +13,10 @@ from fastNLP.models.sequence_modeling import SeqLabeling
 from fastNLP.saver.model_saver import ModelSaver

 data_name = "pku_training.utf8"
-cws_data_path = "test/data_for_tests/cws_pku_utf_8"
+cws_data_path = "./test/data_for_tests/cws_pku_utf_8"
 pickle_path = "./save/"
-data_infer_path = "test/data_for_tests/people_infer.txt"
-config_path = "test/data_for_tests/config"
+data_infer_path = "./test/data_for_tests/people_infer.txt"
+config_path = "./test/data_for_tests/config"

 def infer():
     # Load infer configuration, the same as test

@@ -21,7 +21,7 @@ class TestConfigSaver(unittest.TestCase):

         standard_section = ConfigSection()
         t_section = ConfigSection()
-        ConfigLoader(config_file_path).load_config(config_file_path, {"test": standard_section, "t": t_section})
+        ConfigLoader().load_config(config_file_path, {"test": standard_section, "t": t_section})

         config_saver = ConfigSaver(config_file_path)

@@ -48,11 +48,11 @@ class TestConfigSaver(unittest.TestCase):
         one_another_test_section = ConfigSection()
         a_test_case_2_section = ConfigSection()

-        ConfigLoader(config_file_path).load_config(config_file_path, {"test": test_section,
+        ConfigLoader().load_config(config_file_path, {"test": test_section,
                                                       "another-test": another_test_section,
                                                       "t": at_section,
                                                       "one-another-test": one_another_test_section,
                                                       "test-case-2": a_test_case_2_section})

         assert test_section == standard_section
         assert at_section == t_section

@@ -54,7 +54,7 @@ def mock_cws():
     class2id = Vocabulary(need_default=False)
     label_list = ['B', 'M', 'E', 'S']
     class2id.update(label_list)
-    save_pickle(class2id, "./mock/", "class2id.pkl")
+    save_pickle(class2id, "./mock/", "label2id.pkl")

     model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)}
     config_file = """
@@ -115,7 +115,7 @@ def mock_pos_tag():
     idx2label = Vocabulary(need_default=False)
     label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv']
     idx2label.update(label_list)
-    save_pickle(idx2label, "./mock/", "class2id.pkl")
+    save_pickle(idx2label, "./mock/", "label2id.pkl")

     model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
     config_file = """
@@ -163,7 +163,7 @@ def mock_text_classify():
     idx2label = Vocabulary(need_default=False)
     label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F']
     idx2label.update(label_list)
-    save_pickle(idx2label, "./mock/", "class2id.pkl")
+    save_pickle(idx2label, "./mock/", "label2id.pkl")

     model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
     config_file = """