mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 04:37:37 +08:00
Merge Preprocessor into DataSet.
- DataSet's __init__ takes a function as argument, rather than class object - Preprocessor is about to remove. Don't use anymore. - Remove cross_validate in trainer, because it is rarely used and wired - Loader.load is expected to be a static method - Delete sth. in other_modules.py - Add more tests - Delete extra sample data
This commit is contained in:
parent
1d4e406e6f
commit
5be4cb7bb5
@ -70,18 +70,18 @@ class DataSet(list):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name="", instances=None, loader=None):
|
||||
def __init__(self, name="", instances=None, load_func=None):
|
||||
"""
|
||||
|
||||
:param name: str, the name of the dataset. (default: "")
|
||||
:param instances: list of Instance objects. (default: None)
|
||||
|
||||
:param load_func: a function that takes the dataset path (string) as input and returns multi-level lists.
|
||||
"""
|
||||
list.__init__([])
|
||||
self.name = name
|
||||
if instances is not None:
|
||||
self.extend(instances)
|
||||
self.dataset_loader = loader
|
||||
self.data_set_load_func = load_func
|
||||
|
||||
def index_all(self, vocab):
|
||||
for ins in self:
|
||||
@ -117,15 +117,15 @@ class DataSet(list):
|
||||
return lengths
|
||||
|
||||
def convert(self, data):
|
||||
"""Convert lists of strings into Instances with Fields"""
|
||||
"""Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training."""
|
||||
raise NotImplementedError
|
||||
|
||||
def convert_with_vocabs(self, data, vocabs):
|
||||
"""Convert lists of strings into Instances with Fields, using existing Vocabulary. Useful in predicting."""
|
||||
"""Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing."""
|
||||
raise NotImplementedError
|
||||
|
||||
def convert_for_infer(self, data, vocabs):
|
||||
"""Convert lists of strings into Instances with Fields."""
|
||||
"""Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting."""
|
||||
|
||||
def load(self, data_path, vocabs=None, infer=False):
|
||||
"""Load data from the given files.
|
||||
@ -135,7 +135,7 @@ class DataSet(list):
|
||||
:param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.
|
||||
|
||||
"""
|
||||
raw_data = self.dataset_loader.load(data_path)
|
||||
raw_data = self.data_set_load_func(data_path)
|
||||
if infer is True:
|
||||
self.convert_for_infer(raw_data, vocabs)
|
||||
else:
|
||||
@ -145,7 +145,7 @@ class DataSet(list):
|
||||
self.convert(raw_data)
|
||||
|
||||
def load_raw(self, raw_data, vocabs):
|
||||
"""
|
||||
"""Load raw data without loader. Used in FastNLP class.
|
||||
|
||||
:param raw_data:
|
||||
:param vocabs:
|
||||
@ -174,8 +174,8 @@ class DataSet(list):
|
||||
|
||||
|
||||
class SeqLabelDataSet(DataSet):
|
||||
def __init__(self, instances=None, loader=POSDataSetLoader()):
|
||||
super(SeqLabelDataSet, self).__init__(name="", instances=instances, loader=loader)
|
||||
def __init__(self, instances=None, load_func=POSDataSetLoader().load):
|
||||
super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func)
|
||||
self.word_vocab = Vocabulary()
|
||||
self.label_vocab = Vocabulary()
|
||||
|
||||
@ -231,8 +231,8 @@ class SeqLabelDataSet(DataSet):
|
||||
|
||||
|
||||
class TextClassifyDataSet(DataSet):
|
||||
def __init__(self, instances=None, loader=ClassDataSetLoader()):
|
||||
super(TextClassifyDataSet, self).__init__(name="", instances=instances, loader=loader)
|
||||
def __init__(self, instances=None, load_func=ClassDataSetLoader().load):
|
||||
super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func)
|
||||
self.word_vocab = Vocabulary()
|
||||
self.label_vocab = Vocabulary(need_default=False)
|
||||
|
||||
@ -285,10 +285,3 @@ def change_field_is_target(data_set, field_name, new_target):
|
||||
for inst in data_set:
|
||||
inst.fields[field_name].is_target = new_target
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_set = SeqLabelDataSet()
|
||||
data_set.load("../../test/data_for_tests/people.txt")
|
||||
a, b = data_set.split(0.3)
|
||||
print(type(data_set), type(a), type(b))
|
||||
print(len(data_set), len(a), len(b))
|
||||
|
@ -78,6 +78,7 @@ class Preprocessor(object):
|
||||
is only available when label_is_seq is True. Default: False.
|
||||
:param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
|
||||
"""
|
||||
print("Preprocessor is about to deprecate. Please use DataSet class.")
|
||||
self.data_vocab = Vocabulary()
|
||||
if label_is_seq is True:
|
||||
if share_vocab is True:
|
||||
@ -307,11 +308,3 @@ class ClassPreprocess(Preprocessor):
|
||||
print("[FastNLP warning] ClassPreprocess is about to deprecate. Please use Preprocess directly.")
|
||||
super(ClassPreprocess, self).__init__()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = Preprocessor()
|
||||
train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
|
||||
[["You", "are", "pretty", "."], "1"]
|
||||
]
|
||||
training_set = p.run(train_dev_data)
|
||||
print(training_set)
|
||||
|
@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import os
|
||||
import time
|
||||
from datetime import timedelta
|
||||
@ -178,31 +177,6 @@ class Trainer(object):
|
||||
logger.info(print_output)
|
||||
step += 1
|
||||
|
||||
def cross_validate(self, network, train_data_cv, dev_data_cv):
|
||||
"""Training with cross validation.
|
||||
|
||||
:param network: the model
|
||||
:param train_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
|
||||
:param dev_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
|
||||
|
||||
"""
|
||||
if len(train_data_cv) != len(dev_data_cv):
|
||||
logger.error("the number of folds in train and dev data unequals {}!={}".format(len(train_data_cv),
|
||||
len(dev_data_cv)))
|
||||
raise RuntimeError("the number of folds in train and dev data unequals")
|
||||
if self.validate is False:
|
||||
logger.warn("Cross validation requires self.validate to be True. Please turn it on. ")
|
||||
print("[warning] Cross validation requires self.validate to be True. Please turn it on. ")
|
||||
self.validate = True
|
||||
|
||||
n_fold = len(train_data_cv)
|
||||
logger.info("perform {} folds cross validation.".format(n_fold))
|
||||
for i in range(n_fold):
|
||||
print("CV:", i)
|
||||
logger.info("running the {} of {} folds cross validation".format(i + 1, n_fold))
|
||||
network_copy = copy.deepcopy(network)
|
||||
self.train(network_copy, train_data_cv[i], dev_data_cv[i])
|
||||
|
||||
def mode(self, model, is_test=False):
|
||||
"""Train mode or Test mode. This is for PyTorch currently.
|
||||
|
||||
|
@ -1,11 +1,10 @@
|
||||
import os
|
||||
|
||||
from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
|
||||
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
|
||||
from fastNLP.core.preprocess import load_pickle
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
|
||||
|
||||
|
||||
"""
|
||||
mapping from model name to [URL, file_name.class_name, model_pickle_name]
|
||||
@ -73,7 +72,7 @@ class FastNLP(object):
|
||||
:param model_dir: this directory should contain the following files:
|
||||
1. a trained model
|
||||
2. a config file, which is a fastNLP's configuration.
|
||||
3. a Vocab file, which is a pickle object of a Vocab instance.
|
||||
3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs.
|
||||
"""
|
||||
self.model_dir = model_dir
|
||||
self.model = None
|
||||
@ -192,7 +191,7 @@ class FastNLP(object):
|
||||
|
||||
|
||||
def _load(self, model_dir, model_name):
|
||||
# To do
|
||||
|
||||
return 0
|
||||
|
||||
def _download(self, model_name, url):
|
||||
@ -202,7 +201,7 @@ class FastNLP(object):
|
||||
:param url:
|
||||
"""
|
||||
print("Downloading {} from {}".format(model_name, url))
|
||||
# To do
|
||||
# TODO: download model via url
|
||||
|
||||
def model_exist(self, model_dir):
|
||||
"""
|
||||
|
@ -3,12 +3,14 @@ class BaseLoader(object):
|
||||
def __init__(self):
|
||||
super(BaseLoader, self).__init__()
|
||||
|
||||
def load_lines(self, data_path):
|
||||
@staticmethod
|
||||
def load_lines(data_path):
|
||||
with open(data_path, "r", encoding="utf=8") as f:
|
||||
text = f.readlines()
|
||||
return [line.strip() for line in text]
|
||||
|
||||
def load(self, data_path):
|
||||
@staticmethod
|
||||
def load(data_path):
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
text = f.readlines()
|
||||
return [[word for word in sent.strip()] for sent in text]
|
||||
|
@ -84,7 +84,8 @@ class TokenizeDataSetLoader(DataSetLoader):
|
||||
def __init__(self):
|
||||
super(TokenizeDataSetLoader, self).__init__()
|
||||
|
||||
def load(self, data_path, max_seq_len=32):
|
||||
@staticmethod
|
||||
def load(data_path, max_seq_len=32):
|
||||
"""
|
||||
load pku dataset for Chinese word segmentation
|
||||
CWS (Chinese Word Segmentation) pku training dataset format:
|
||||
|
@ -196,30 +196,3 @@ class BiAffine(nn.Module):
|
||||
output = output * mask_d.unsqueeze(1).unsqueeze(3) * mask_e.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Transpose(nn.Module):
|
||||
def __init__(self, x, y):
|
||||
super(Transpose, self).__init__()
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def forward(self, x):
|
||||
return x.transpose(self.x, self.y)
|
||||
|
||||
|
||||
class WordDropout(nn.Module):
|
||||
def __init__(self, dropout_rate, drop_to_token):
|
||||
super(WordDropout, self).__init__()
|
||||
self.dropout_rate = dropout_rate
|
||||
self.drop_to_token = drop_to_token
|
||||
|
||||
def forward(self, word_idx):
|
||||
if not self.training:
|
||||
return word_idx
|
||||
drop_mask = torch.rand(word_idx.shape) < self.dropout_rate
|
||||
if word_idx.device.type == 'cuda':
|
||||
drop_mask = drop_mask.cuda()
|
||||
drop_mask = drop_mask.long()
|
||||
output = drop_mask * self.drop_to_token + (1 - drop_mask) * word_idx
|
||||
return output
|
||||
|
@ -104,7 +104,8 @@ class ConfigSaver(object):
|
||||
:return:
|
||||
"""
|
||||
section_file = self._get_section(section_name)
|
||||
if len(section_file.__dict__.keys()) == 0:#the section not in file before
|
||||
if len(section_file.__dict__.keys()) == 0: # the section not in the file before
|
||||
# append this section to config file
|
||||
with open(self.file_path, 'a') as f:
|
||||
f.write('[' + section_name + ']\n')
|
||||
for k in section.__dict__.keys():
|
||||
@ -114,9 +115,11 @@ class ConfigSaver(object):
|
||||
else:
|
||||
f.write(str(section[k]) + '\n\n')
|
||||
else:
|
||||
# the section exists
|
||||
change_file = False
|
||||
for k in section.__dict__.keys():
|
||||
if k not in section_file:
|
||||
# find a new key in this section
|
||||
change_file = True
|
||||
break
|
||||
if section_file[k] != section[k]:
|
||||
|
243
test/core/test_dataset.py
Normal file
243
test/core/test_dataset.py
Normal file
@ -0,0 +1,243 @@
|
||||
import unittest
|
||||
|
||||
from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
|
||||
from fastNLP.core.dataset import create_dataset_from_lists
|
||||
|
||||
|
||||
class TestDataSet(unittest.TestCase):
|
||||
labeled_data_list = [
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
]
|
||||
unlabeled_data_list = [
|
||||
["a", "b", "e", "d"],
|
||||
["a", "b", "e", "d"],
|
||||
["a", "b", "e", "d"]
|
||||
]
|
||||
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
|
||||
label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}
|
||||
|
||||
def test_case_1(self):
|
||||
data_set = create_dataset_from_lists(self.labeled_data_list, self.word_vocab, has_target=True,
|
||||
label_vocab=self.label_vocab)
|
||||
self.assertEqual(len(data_set), len(self.labeled_data_list))
|
||||
self.assertTrue(len(data_set) > 0)
|
||||
self.assertTrue(hasattr(data_set[0], "fields"))
|
||||
self.assertTrue("word_seq" in data_set[0].fields)
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
|
||||
self.assertEqual(data_set[0].fields["word_seq"]._index,
|
||||
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
|
||||
|
||||
self.assertTrue("label_seq" in data_set[0].fields)
|
||||
self.assertTrue(hasattr(data_set[0].fields["label_seq"], "text"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["label_seq"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["label_seq"].text, self.labeled_data_list[0][1])
|
||||
self.assertEqual(data_set[0].fields["label_seq"]._index,
|
||||
[self.label_vocab[c] for c in self.labeled_data_list[0][1]])
|
||||
|
||||
def test_case_2(self):
|
||||
data_set = create_dataset_from_lists(self.unlabeled_data_list, self.word_vocab, has_target=False)
|
||||
|
||||
self.assertEqual(len(data_set), len(self.unlabeled_data_list))
|
||||
self.assertTrue(len(data_set) > 0)
|
||||
self.assertTrue(hasattr(data_set[0], "fields"))
|
||||
self.assertTrue("word_seq" in data_set[0].fields)
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.unlabeled_data_list[0])
|
||||
self.assertEqual(data_set[0].fields["word_seq"]._index,
|
||||
[self.word_vocab[c] for c in self.unlabeled_data_list[0]])
|
||||
|
||||
|
||||
class TestDataSetConvertion(unittest.TestCase):
|
||||
labeled_data_list = [
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
]
|
||||
unlabeled_data_list = [
|
||||
["a", "b", "e", "d"],
|
||||
["a", "b", "e", "d"],
|
||||
["a", "b", "e", "d"]
|
||||
]
|
||||
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
|
||||
label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}
|
||||
|
||||
def test_case_1(self):
|
||||
def loader(path):
|
||||
labeled_data_list = [
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
]
|
||||
return labeled_data_list
|
||||
|
||||
data_set = SeqLabelDataSet(load_func=loader)
|
||||
data_set.load("any_path")
|
||||
|
||||
self.assertEqual(len(data_set), len(self.labeled_data_list))
|
||||
self.assertTrue(len(data_set) > 0)
|
||||
self.assertTrue(hasattr(data_set[0], "fields"))
|
||||
self.assertTrue("word_seq" in data_set[0].fields)
|
||||
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
|
||||
|
||||
self.assertTrue("truth" in data_set[0].fields)
|
||||
self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])
|
||||
|
||||
self.assertTrue("word_seq_origin_len" in data_set[0].fields)
|
||||
|
||||
def test_case_2(self):
|
||||
def loader(path):
|
||||
unlabeled_data_list = [
|
||||
["a", "b", "e", "d"],
|
||||
["a", "b", "e", "d"],
|
||||
["a", "b", "e", "d"]
|
||||
]
|
||||
return unlabeled_data_list
|
||||
|
||||
data_set = SeqLabelDataSet(load_func=loader)
|
||||
data_set.load("any_path", vocabs={"word_vocab": self.word_vocab}, infer=True)
|
||||
|
||||
self.assertEqual(len(data_set), len(self.labeled_data_list))
|
||||
self.assertTrue(len(data_set) > 0)
|
||||
self.assertTrue(hasattr(data_set[0], "fields"))
|
||||
self.assertTrue("word_seq" in data_set[0].fields)
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
|
||||
self.assertEqual(data_set[0].fields["word_seq"]._index,
|
||||
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
|
||||
|
||||
self.assertTrue("word_seq_origin_len" in data_set[0].fields)
|
||||
|
||||
def test_case_3(self):
|
||||
def loader(path):
|
||||
labeled_data_list = [
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
|
||||
]
|
||||
return labeled_data_list
|
||||
|
||||
data_set = SeqLabelDataSet(load_func=loader)
|
||||
data_set.load("any_path", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})
|
||||
|
||||
self.assertEqual(len(data_set), len(self.labeled_data_list))
|
||||
self.assertTrue(len(data_set) > 0)
|
||||
self.assertTrue(hasattr(data_set[0], "fields"))
|
||||
self.assertTrue("word_seq" in data_set[0].fields)
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
|
||||
self.assertEqual(data_set[0].fields["word_seq"]._index,
|
||||
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
|
||||
|
||||
self.assertTrue("truth" in data_set[0].fields)
|
||||
self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])
|
||||
self.assertEqual(data_set[0].fields["truth"]._index,
|
||||
[self.label_vocab[c] for c in self.labeled_data_list[0][1]])
|
||||
|
||||
self.assertTrue("word_seq_origin_len" in data_set[0].fields)
|
||||
|
||||
|
||||
class TestDataSetConvertionHHH(unittest.TestCase):
|
||||
labeled_data_list = [
|
||||
[["a", "b", "e", "d"], "A"],
|
||||
[["a", "b", "e", "d"], "C"],
|
||||
[["a", "b", "e", "d"], "B"],
|
||||
]
|
||||
unlabeled_data_list = [
|
||||
["a", "b", "e", "d"],
|
||||
["a", "b", "e", "d"],
|
||||
["a", "b", "e", "d"]
|
||||
]
|
||||
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
|
||||
label_vocab = {"A": 1, "B": 2, "C": 3}
|
||||
|
||||
def test_case_1(self):
|
||||
def loader(path):
|
||||
labeled_data_list = [
|
||||
[["a", "b", "e", "d"], "A"],
|
||||
[["a", "b", "e", "d"], "C"],
|
||||
[["a", "b", "e", "d"], "B"],
|
||||
]
|
||||
return labeled_data_list
|
||||
|
||||
data_set = TextClassifyDataSet(load_func=loader)
|
||||
data_set.load("xxx")
|
||||
|
||||
self.assertEqual(len(data_set), len(self.labeled_data_list))
|
||||
self.assertTrue(len(data_set) > 0)
|
||||
self.assertTrue(hasattr(data_set[0], "fields"))
|
||||
self.assertTrue("word_seq" in data_set[0].fields)
|
||||
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
|
||||
|
||||
self.assertTrue("label" in data_set[0].fields)
|
||||
self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])
|
||||
|
||||
def test_case_2(self):
|
||||
def loader(path):
|
||||
labeled_data_list = [
|
||||
[["a", "b", "e", "d"], "A"],
|
||||
[["a", "b", "e", "d"], "C"],
|
||||
[["a", "b", "e", "d"], "B"],
|
||||
]
|
||||
return labeled_data_list
|
||||
|
||||
data_set = TextClassifyDataSet(load_func=loader)
|
||||
data_set.load("xxx", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})
|
||||
|
||||
self.assertEqual(len(data_set), len(self.labeled_data_list))
|
||||
self.assertTrue(len(data_set) > 0)
|
||||
self.assertTrue(hasattr(data_set[0], "fields"))
|
||||
self.assertTrue("word_seq" in data_set[0].fields)
|
||||
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
|
||||
self.assertEqual(data_set[0].fields["word_seq"]._index,
|
||||
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
|
||||
|
||||
self.assertTrue("label" in data_set[0].fields)
|
||||
self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])
|
||||
self.assertEqual(data_set[0].fields["label"]._index, self.label_vocab[self.labeled_data_list[0][1]])
|
||||
|
||||
def test_case_3(self):
|
||||
def loader(path):
|
||||
unlabeled_data_list = [
|
||||
["a", "b", "e", "d"],
|
||||
["a", "b", "e", "d"],
|
||||
["a", "b", "e", "d"]
|
||||
]
|
||||
return unlabeled_data_list
|
||||
|
||||
data_set = TextClassifyDataSet(load_func=loader)
|
||||
data_set.load("xxx", vocabs={"word_vocab": self.word_vocab}, infer=True)
|
||||
|
||||
self.assertEqual(len(data_set), len(self.labeled_data_list))
|
||||
self.assertTrue(len(data_set) > 0)
|
||||
self.assertTrue(hasattr(data_set[0], "fields"))
|
||||
self.assertTrue("word_seq" in data_set[0].fields)
|
||||
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
|
||||
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
|
||||
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
|
||||
self.assertEqual(data_set[0].fields["word_seq"]._index,
|
||||
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
|
@ -1,13 +1,13 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from fastNLP.core.predictor import Predictor
|
||||
from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet
|
||||
from fastNLP.core.predictor import Predictor
|
||||
from fastNLP.core.preprocess import save_pickle
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.loader.base_loader import BaseLoader
|
||||
from fastNLP.models.sequence_modeling import SeqLabeling
|
||||
from fastNLP.models.cnn_text_classification import CNNText
|
||||
from fastNLP.models.sequence_modeling import SeqLabeling
|
||||
|
||||
|
||||
class TestPredictor(unittest.TestCase):
|
||||
@ -42,7 +42,7 @@ class TestPredictor(unittest.TestCase):
|
||||
predictor = Predictor("./save/", pre.text_classify_post_processor)
|
||||
|
||||
# Load infer data
|
||||
infer_data_set = TextClassifyDataSet(loader=BaseLoader())
|
||||
infer_data_set = TextClassifyDataSet(load_func=BaseLoader.load)
|
||||
infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
|
||||
|
||||
results = predictor.predict(network=model, data=infer_data_set)
|
||||
@ -59,7 +59,7 @@ class TestPredictor(unittest.TestCase):
|
||||
model = SeqLabeling(model_args)
|
||||
predictor = Predictor("./save/", pre.seq_label_post_processor)
|
||||
|
||||
infer_data_set = SeqLabelDataSet(loader=BaseLoader())
|
||||
infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load)
|
||||
infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
|
||||
|
||||
results = predictor.predict(network=model, data=infer_data_set)
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -53,7 +53,7 @@ def infer():
|
||||
print("model loaded!")
|
||||
|
||||
# Data Loader
|
||||
infer_data = SeqLabelDataSet(loader=BaseLoader())
|
||||
infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
|
||||
infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True)
|
||||
print("data set prepared")
|
||||
|
||||
|
@ -37,7 +37,7 @@ def infer():
|
||||
print("model loaded!")
|
||||
|
||||
# Load infer data
|
||||
infer_data = SeqLabelDataSet(loader=BaseLoader())
|
||||
infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
|
||||
infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)
|
||||
|
||||
# inference
|
||||
@ -52,7 +52,7 @@ def train_test():
|
||||
ConfigLoader().load_config(config_path, {"POS_infer": train_args})
|
||||
|
||||
# define dataset
|
||||
data_train = SeqLabelDataSet(loader=TokenizeDataSetLoader())
|
||||
data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
|
||||
data_train.load(cws_data_path)
|
||||
train_args["vocab_size"] = len(data_train.word_vocab)
|
||||
train_args["num_classes"] = len(data_train.label_vocab)
|
||||
|
@ -40,7 +40,7 @@ def infer():
|
||||
print("vocabulary size:", len(word_vocab))
|
||||
print("number of classes:", len(label_vocab))
|
||||
|
||||
infer_data = TextClassifyDataSet(loader=ClassDataSetLoader())
|
||||
infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
|
||||
infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab})
|
||||
|
||||
model_args = ConfigSection()
|
||||
@ -67,7 +67,7 @@ def train():
|
||||
|
||||
# load dataset
|
||||
print("Loading data...")
|
||||
data = TextClassifyDataSet(loader=ClassDataSetLoader())
|
||||
data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
|
||||
data.load(train_data_dir)
|
||||
|
||||
print("vocabulary size:", len(data.word_vocab))
|
||||
|
@ -2,7 +2,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear
|
||||
from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
|
||||
|
||||
|
||||
class TestGroupNorm(unittest.TestCase):
|
||||
@ -27,3 +27,25 @@ class TestBiLinear(unittest.TestCase):
|
||||
y = bl(x_left, x_right)
|
||||
print(bl)
|
||||
bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True)
|
||||
|
||||
|
||||
class TestBiAffine(unittest.TestCase):
|
||||
def test_case_1(self):
|
||||
batch_size = 16
|
||||
encoder_length = 21
|
||||
decoder_length = 32
|
||||
layer = BiAffine(10, 10, 25, biaffine=True)
|
||||
decoder_input = torch.randn((batch_size, encoder_length, 10))
|
||||
encoder_input = torch.randn((batch_size, decoder_length, 10))
|
||||
y = layer(decoder_input, encoder_input)
|
||||
self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, decoder_length))
|
||||
|
||||
def test_case_2(self):
|
||||
batch_size = 16
|
||||
encoder_length = 21
|
||||
decoder_length = 32
|
||||
layer = BiAffine(10, 10, 25, biaffine=False)
|
||||
decoder_input = torch.randn((batch_size, encoder_length, 10))
|
||||
encoder_input = torch.randn((batch_size, decoder_length, 10))
|
||||
y = layer(decoder_input, encoder_input)
|
||||
self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, 1))
|
||||
|
@ -1,8 +1,5 @@
|
||||
import os
|
||||
|
||||
import unittest
|
||||
import configparser
|
||||
import json
|
||||
|
||||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
|
||||
from fastNLP.saver.config_saver import ConfigSaver
|
||||
@ -10,7 +7,7 @@ from fastNLP.saver.config_saver import ConfigSaver
|
||||
|
||||
class TestConfigSaver(unittest.TestCase):
|
||||
def test_case_1(self):
|
||||
config_file_dir = "./test/loader/"
|
||||
config_file_dir = "test/loader/"
|
||||
config_file_name = "config"
|
||||
config_file_path = os.path.join(config_file_dir, config_file_name)
|
||||
|
||||
@ -80,3 +77,37 @@ class TestConfigSaver(unittest.TestCase):
|
||||
tmp_config_saver = ConfigSaver("file-NOT-exist")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def test_case_2(self):
|
||||
config = "[section_A]\n[section_B]\n"
|
||||
|
||||
with open("./test.cfg", "w", encoding="utf-8") as f:
|
||||
f.write(config)
|
||||
saver = ConfigSaver("./test.cfg")
|
||||
|
||||
section = ConfigSection()
|
||||
section["doubles"] = 0.8
|
||||
section["tt"] = [1, 2, 3]
|
||||
section["test"] = 105
|
||||
section["str"] = "this is a str"
|
||||
|
||||
saver.save_config_file("section_A", section)
|
||||
|
||||
os.system("rm ./test.cfg")
|
||||
|
||||
def test_case_3(self):
|
||||
config = "[section_A]\ndoubles = 0.9\ntt = [1, 2, 3]\n[section_B]\n"
|
||||
|
||||
with open("./test.cfg", "w", encoding="utf-8") as f:
|
||||
f.write(config)
|
||||
saver = ConfigSaver("./test.cfg")
|
||||
|
||||
section = ConfigSection()
|
||||
section["doubles"] = 0.8
|
||||
section["tt"] = [1, 2, 3]
|
||||
section["test"] = 105
|
||||
section["str"] = "this is a str"
|
||||
|
||||
saver.save_config_file("section_A", section)
|
||||
|
||||
os.system("rm ./test.cfg")
|
||||
|
Loading…
Reference in New Issue
Block a user