Merge Preprocessor into DataSet.

- DataSet's __init__ now takes a function as its argument, rather than a loader class object (see the sketch below)
- Preprocessor is about to be removed. Don't use it anymore.
- Remove cross_validate from Trainer, because it is rarely used and weird
- Loader.load is expected to be a static method
- Delete the unused Transpose and WordDropout modules from other_modules.py
- Add more tests
- Delete extra sample data
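
A minimal sketch of the new construction, following the example scripts updated in this commit (the data path is hypothetical):

    from fastNLP.core.dataset import SeqLabelDataSet
    from fastNLP.loader.dataset_loader import TokenizeDataSetLoader

    # old: data_train = SeqLabelDataSet(loader=TokenizeDataSetLoader())
    # new: pass the load function itself; load() is now a static method
    data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
    data_train.load("path/to/train_data.txt")  # hypothetical path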
FengZiYjun 2018-10-01 16:49:54 +08:00
parent 1d4e406e6f
commit 5be4cb7bb5
17 changed files with 337 additions and 17872 deletions

View File

@@ -70,18 +70,18 @@ class DataSet(list):
"""
def __init__(self, name="", instances=None, loader=None):
def __init__(self, name="", instances=None, load_func=None):
"""
:param name: str, the name of the dataset. (default: "")
:param instances: list of Instance objects. (default: None)
:param load_func: a function that takes the dataset path (string) as input and returns multi-level lists.
"""
list.__init__(self)
self.name = name
if instances is not None:
self.extend(instances)
self.dataset_loader = loader
self.data_set_load_func = load_func
def index_all(self, vocab):
for ins in self:
@@ -117,15 +117,15 @@ class DataSet(list):
return lengths
def convert(self, data):
"""Convert lists of strings into Instances with Fields"""
"""Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training."""
raise NotImplementedError
def convert_with_vocabs(self, data, vocabs):
"""Convert lists of strings into Instances with Fields, using existing Vocabulary. Useful in predicting."""
"""Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing."""
raise NotImplementedError
def convert_for_infer(self, data, vocabs):
"""Convert lists of strings into Instances with Fields."""
"""Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting."""
def load(self, data_path, vocabs=None, infer=False):
"""Load data from the given files.
@@ -135,7 +135,7 @@ class DataSet(list):
:param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.
"""
raw_data = self.dataset_loader.load(data_path)
raw_data = self.data_set_load_func(data_path)
if infer is True:
self.convert_for_infer(raw_data, vocabs)
else:
@@ -145,7 +145,7 @@ class DataSet(list):
self.convert(raw_data)
def load_raw(self, raw_data, vocabs):
"""
"""Load raw data without loader. Used in FastNLP class.
:param raw_data:
:param vocabs:
@@ -174,8 +174,8 @@ class DataSet(list):
class SeqLabelDataSet(DataSet):
def __init__(self, instances=None, loader=POSDataSetLoader()):
super(SeqLabelDataSet, self).__init__(name="", instances=instances, loader=loader)
def __init__(self, instances=None, load_func=POSDataSetLoader().load):
super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func)
self.word_vocab = Vocabulary()
self.label_vocab = Vocabulary()
@@ -231,8 +231,8 @@ class SeqLabelDataSet(DataSet):
class TextClassifyDataSet(DataSet):
def __init__(self, instances=None, loader=ClassDataSetLoader()):
super(TextClassifyDataSet, self).__init__(name="", instances=instances, loader=loader)
def __init__(self, instances=None, load_func=ClassDataSetLoader().load):
super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func)
self.word_vocab = Vocabulary()
self.label_vocab = Vocabulary(need_default=False)
@@ -285,10 +285,3 @@ def change_field_is_target(data_set, field_name, new_target):
for inst in data_set:
inst.fields[field_name].is_target = new_target
if __name__ == "__main__":
data_set = SeqLabelDataSet()
data_set.load("../../test/data_for_tests/people.txt")
a, b = data_set.split(0.3)
print(type(data_set), type(a), type(b))
print(len(data_set), len(a), len(b))
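
For the inference path, the same API is driven with pre-built vocabularies and infer=True, as in the example scripts changed later in this commit (a sketch; the data path and pickle names are hypothetical):

    from fastNLP.core.dataset import SeqLabelDataSet
    from fastNLP.core.preprocess import load_pickle
    from fastNLP.loader.base_loader import BaseLoader

    # vocabularies saved during training; the pickle names are hypothetical
    word_vocab = load_pickle("./save/", "word2id.pkl")
    label_vocab = load_pickle("./save/", "label2id.pkl")

    infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
    infer_data.load("path/to/infer_data.txt",
                    vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab},
                    infer=True)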

View File

@@ -78,6 +78,7 @@ class Preprocessor(object):
is only available when label_is_seq is True. Default: False.
:param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
"""
print("Preprocessor is about to deprecate. Please use DataSet class.")
self.data_vocab = Vocabulary()
if label_is_seq is True:
if share_vocab is True:
@@ -307,11 +308,3 @@ class ClassPreprocess(Preprocessor):
print("[FastNLP warning] ClassPreprocess is about to deprecate. Please use Preprocess directly.")
super(ClassPreprocess, self).__init__()
if __name__ == "__main__":
p = Preprocessor()
train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
[["You", "are", "pretty", "."], "1"]
]
training_set = p.run(train_dev_data)
print(training_set)
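
With Preprocessor on its way out, the deleted example above maps onto the DataSet API roughly as follows (a sketch mirroring the new tests added in this commit; the load function simply ignores its path argument):

    from fastNLP.core.dataset import TextClassifyDataSet

    def loader(path):
        # the same toy data the deleted __main__ block used
        return [[["I", "am", "a", "good", "student", "."], "0"],
                [["You", "are", "pretty", "."], "1"]]

    data_set = TextClassifyDataSet(load_func=loader)
    data_set.load("unused_path")  # vocabularies are built during conversion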

View File

@@ -1,4 +1,3 @@
import copy
import os
import time
from datetime import timedelta
@@ -178,31 +177,6 @@ class Trainer(object):
logger.info(print_output)
step += 1
def cross_validate(self, network, train_data_cv, dev_data_cv):
"""Training with cross validation.
:param network: the model
:param train_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
:param dev_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
"""
if len(train_data_cv) != len(dev_data_cv):
logger.error("the numbers of folds in train and dev data are unequal: {} != {}".format(len(train_data_cv),
len(dev_data_cv)))
raise RuntimeError("the numbers of folds in train and dev data are unequal")
if self.validate is False:
logger.warn("Cross validation requires self.validate to be True. Please turn it on. ")
print("[warning] Cross validation requires self.validate to be True. Please turn it on. ")
self.validate = True
n_fold = len(train_data_cv)
logger.info("perform {} folds cross validation.".format(n_fold))
for i in range(n_fold):
print("CV:", i)
logger.info("running the {} of {} folds cross validation".format(i + 1, n_fold))
network_copy = copy.deepcopy(network)
self.train(network_copy, train_data_cv[i], dev_data_cv[i])
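
Callers who still need k-fold training can reproduce the deleted body externally (a sketch, assuming train_data_cv and dev_data_cv keep the per-fold layout described above and trainer is a configured Trainer instance):

    import copy

    for fold, (train_data, dev_data) in enumerate(zip(train_data_cv, dev_data_cv)):
        print("CV:", fold)
        network_copy = copy.deepcopy(network)  # fresh parameters for every fold
        trainer.train(network_copy, train_data, dev_data)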
def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.

View File

@@ -1,11 +1,10 @@
import os
from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
from fastNLP.core.preprocess import load_pickle
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
"""
mapping from model name to [URL, file_name.class_name, model_pickle_name]
@@ -73,7 +72,7 @@ class FastNLP(object):
:param model_dir: this directory should contain the following files:
1. a trained model
2. a config file, which is a fastNLP's configuration.
3. a Vocab file, which is a pickle object of a Vocab instance.
3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs.
"""
self.model_dir = model_dir
self.model = None
@@ -192,7 +191,7 @@ class FastNLP(object):
def _load(self, model_dir, model_name):
# To do
return 0
def _download(self, model_name, url):
@@ -202,7 +201,7 @@ class FastNLP(object):
:param url:
"""
print("Downloading {} from {}".format(model_name, url))
# To do
# TODO: download model via url
def model_exist(self, model_dir):
"""

View File

@@ -3,12 +3,14 @@ class BaseLoader(object):
def __init__(self):
super(BaseLoader, self).__init__()
def load_lines(self, data_path):
@staticmethod
def load_lines(data_path):
with open(data_path, "r", encoding="utf=8") as f:
text = f.readlines()
return [line.strip() for line in text]
def load(self, data_path):
@staticmethod
def load(data_path):
with open(data_path, "r", encoding="utf-8") as f:
text = f.readlines()
return [[word for word in sent.strip()] for sent in text]
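
Because load is now a @staticmethod, it can be called on the class and handed straight to a DataSet without instantiating BaseLoader (a sketch; the file path is hypothetical):

    from fastNLP.core.dataset import SeqLabelDataSet
    from fastNLP.loader.base_loader import BaseLoader

    chars = BaseLoader.load("path/to/raw.txt")  # one list of characters per line
    infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load)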

View File

@@ -84,7 +84,8 @@ class TokenizeDataSetLoader(DataSetLoader):
def __init__(self):
super(TokenizeDataSetLoader, self).__init__()
def load(self, data_path, max_seq_len=32):
@staticmethod
def load(data_path, max_seq_len=32):
"""
load pku dataset for Chinese word segmentation
CWS (Chinese Word Segmentation) pku training dataset format:

View File

@@ -196,30 +196,3 @@ class BiAffine(nn.Module):
output = output * mask_d.unsqueeze(1).unsqueeze(3) * mask_e.unsqueeze(1).unsqueeze(2)
return output
class Transpose(nn.Module):
def __init__(self, x, y):
super(Transpose, self).__init__()
self.x = x
self.y = y
def forward(self, x):
return x.transpose(self.x, self.y)
class WordDropout(nn.Module):
def __init__(self, dropout_rate, drop_to_token):
super(WordDropout, self).__init__()
self.dropout_rate = dropout_rate
self.drop_to_token = drop_to_token
def forward(self, word_idx):
if not self.training:
return word_idx
drop_mask = torch.rand(word_idx.shape) < self.dropout_rate
if word_idx.device.type == 'cuda':
drop_mask = drop_mask.cuda()
drop_mask = drop_mask.long()
output = drop_mask * self.drop_to_token + (1 - drop_mask) * word_idx
return output

View File

@@ -104,7 +104,8 @@ class ConfigSaver(object):
:return:
"""
section_file = self._get_section(section_name)
if len(section_file.__dict__.keys()) == 0:#the section not in file before
if len(section_file.__dict__.keys()) == 0: # the section not in the file before
# append this section to config file
with open(self.file_path, 'a') as f:
f.write('[' + section_name + ']\n')
for k in section.__dict__.keys():
@@ -114,9 +115,11 @@ class ConfigSaver(object):
else:
f.write(str(section[k]) + '\n\n')
else:
# the section exists
change_file = False
for k in section.__dict__.keys():
if k not in section_file:
# find a new key in this section
change_file = True
break
if section_file[k] != section[k]:

test/core/test_dataset.py (new file, 243 lines)
View File

@@ -0,0 +1,243 @@
import unittest
from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
from fastNLP.core.dataset import create_dataset_from_lists
class TestDataSet(unittest.TestCase):
labeled_data_list = [
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
]
unlabeled_data_list = [
["a", "b", "e", "d"],
["a", "b", "e", "d"],
["a", "b", "e", "d"]
]
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}
def test_case_1(self):
data_set = create_dataset_from_lists(self.labeled_data_list, self.word_vocab, has_target=True,
label_vocab=self.label_vocab)
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
self.assertTrue("label_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["label_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["label_seq"], "_index"))
self.assertEqual(data_set[0].fields["label_seq"].text, self.labeled_data_list[0][1])
self.assertEqual(data_set[0].fields["label_seq"]._index,
[self.label_vocab[c] for c in self.labeled_data_list[0][1]])
def test_case_2(self):
data_set = create_dataset_from_lists(self.unlabeled_data_list, self.word_vocab, has_target=False)
self.assertEqual(len(data_set), len(self.unlabeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.unlabeled_data_list[0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.unlabeled_data_list[0]])
class TestDataSetConvertion(unittest.TestCase):
labeled_data_list = [
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
]
unlabeled_data_list = [
["a", "b", "e", "d"],
["a", "b", "e", "d"],
["a", "b", "e", "d"]
]
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}
def test_case_1(self):
def loader(path):
labeled_data_list = [
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
]
return labeled_data_list
data_set = SeqLabelDataSet(load_func=loader)
data_set.load("any_path")
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertTrue("truth" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])
self.assertTrue("word_seq_origin_len" in data_set[0].fields)
def test_case_2(self):
def loader(path):
unlabeled_data_list = [
["a", "b", "e", "d"],
["a", "b", "e", "d"],
["a", "b", "e", "d"]
]
return unlabeled_data_list
data_set = SeqLabelDataSet(load_func=loader)
data_set.load("any_path", vocabs={"word_vocab": self.word_vocab}, infer=True)
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
self.assertTrue("word_seq_origin_len" in data_set[0].fields)
def test_case_3(self):
def loader(path):
labeled_data_list = [
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
]
return labeled_data_list
data_set = SeqLabelDataSet(load_func=loader)
data_set.load("any_path", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
self.assertTrue("truth" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["truth"], "text"))
self.assertTrue(hasattr(data_set[0].fields["truth"], "_index"))
self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1])
self.assertEqual(data_set[0].fields["truth"]._index,
[self.label_vocab[c] for c in self.labeled_data_list[0][1]])
self.assertTrue("word_seq_origin_len" in data_set[0].fields)
class TestDataSetConvertionHHH(unittest.TestCase):
labeled_data_list = [
[["a", "b", "e", "d"], "A"],
[["a", "b", "e", "d"], "C"],
[["a", "b", "e", "d"], "B"],
]
unlabeled_data_list = [
["a", "b", "e", "d"],
["a", "b", "e", "d"],
["a", "b", "e", "d"]
]
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
label_vocab = {"A": 1, "B": 2, "C": 3}
def test_case_1(self):
def loader(path):
labeled_data_list = [
[["a", "b", "e", "d"], "A"],
[["a", "b", "e", "d"], "C"],
[["a", "b", "e", "d"], "B"],
]
return labeled_data_list
data_set = TextClassifyDataSet(load_func=loader)
data_set.load("xxx")
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertTrue("label" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])
def test_case_2(self):
def loader(path):
labeled_data_list = [
[["a", "b", "e", "d"], "A"],
[["a", "b", "e", "d"], "C"],
[["a", "b", "e", "d"], "B"],
]
return labeled_data_list
data_set = TextClassifyDataSet(load_func=loader)
data_set.load("xxx", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab})
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])
self.assertTrue("label" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["label"], "label"))
self.assertTrue(hasattr(data_set[0].fields["label"], "_index"))
self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1])
self.assertEqual(data_set[0].fields["label"]._index, self.label_vocab[self.labeled_data_list[0][1]])
def test_case_3(self):
def loader(path):
unlabeled_data_list = [
["a", "b", "e", "d"],
["a", "b", "e", "d"],
["a", "b", "e", "d"]
]
return unlabeled_data_list
data_set = TextClassifyDataSet(load_func=loader)
data_set.load("xxx", vocabs={"word_vocab": self.word_vocab}, infer=True)
self.assertEqual(len(data_set), len(self.labeled_data_list))
self.assertTrue(len(data_set) > 0)
self.assertTrue(hasattr(data_set[0], "fields"))
self.assertTrue("word_seq" in data_set[0].fields)
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text"))
self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index"))
self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0])
self.assertEqual(data_set[0].fields["word_seq"]._index,
[self.word_vocab[c] for c in self.labeled_data_list[0][0]])

View File

@@ -1,13 +1,13 @@
import os
import unittest
from fastNLP.core.predictor import Predictor
from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet
from fastNLP.core.predictor import Predictor
from fastNLP.core.preprocess import save_pickle
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.base_loader import BaseLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.models.sequence_modeling import SeqLabeling
class TestPredictor(unittest.TestCase):
@@ -42,7 +42,7 @@ class TestPredictor(unittest.TestCase):
predictor = Predictor("./save/", pre.text_classify_post_processor)
# Load infer data
infer_data_set = TextClassifyDataSet(loader=BaseLoader())
infer_data_set = TextClassifyDataSet(load_func=BaseLoader.load)
infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
results = predictor.predict(network=model, data=infer_data_set)
@@ -59,7 +59,7 @@ class TestPredictor(unittest.TestCase):
model = SeqLabeling(model_args)
predictor = Predictor("./save/", pre.seq_label_post_processor)
infer_data_set = SeqLabelDataSet(loader=BaseLoader())
infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load)
infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
results = predictor.predict(network=model, data=infer_data_set)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -53,7 +53,7 @@ def infer():
print("model loaded!")
# Data Loader
infer_data = SeqLabelDataSet(loader=BaseLoader())
infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True)
print("data set prepared")

View File

@@ -37,7 +37,7 @@ def infer():
print("model loaded!")
# Load infer data
infer_data = SeqLabelDataSet(loader=BaseLoader())
infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)
# inference
@@ -52,7 +52,7 @@ def train_test():
ConfigLoader().load_config(config_path, {"POS_infer": train_args})
# define dataset
data_train = SeqLabelDataSet(loader=TokenizeDataSetLoader())
data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
data_train.load(cws_data_path)
train_args["vocab_size"] = len(data_train.word_vocab)
train_args["num_classes"] = len(data_train.label_vocab)

View File

@@ -40,7 +40,7 @@ def infer():
print("vocabulary size:", len(word_vocab))
print("number of classes:", len(label_vocab))
infer_data = TextClassifyDataSet(loader=ClassDataSetLoader())
infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab})
model_args = ConfigSection()
@@ -67,7 +67,7 @@ def train():
# load dataset
print("Loading data...")
data = TextClassifyDataSet(loader=ClassDataSetLoader())
data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
data.load(train_data_dir)
print("vocabulary size:", len(data.word_vocab))

View File

@@ -2,7 +2,7 @@ import unittest
import torch
from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear
from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
class TestGroupNorm(unittest.TestCase):
@@ -27,3 +27,25 @@ class TestBiLinear(unittest.TestCase):
y = bl(x_left, x_right)
print(bl)
bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True)
class TestBiAffine(unittest.TestCase):
def test_case_1(self):
batch_size = 16
encoder_length = 21
decoder_length = 32
layer = BiAffine(10, 10, 25, biaffine=True)
decoder_input = torch.randn((batch_size, encoder_length, 10))
encoder_input = torch.randn((batch_size, decoder_length, 10))
y = layer(decoder_input, encoder_input)
self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, decoder_length))
def test_case_2(self):
batch_size = 16
encoder_length = 21
decoder_length = 32
layer = BiAffine(10, 10, 25, biaffine=False)
decoder_input = torch.randn((batch_size, encoder_length, 10))
encoder_input = torch.randn((batch_size, decoder_length, 10))
y = layer(decoder_input, encoder_input)
self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, 1))

View File

@@ -1,8 +1,5 @@
import os
import unittest
import configparser
import json
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.saver.config_saver import ConfigSaver
@@ -10,7 +7,7 @@ from fastNLP.saver.config_saver import ConfigSaver
class TestConfigSaver(unittest.TestCase):
def test_case_1(self):
config_file_dir = "./test/loader/"
config_file_dir = "test/loader/"
config_file_name = "config"
config_file_path = os.path.join(config_file_dir, config_file_name)
@@ -80,3 +77,37 @@ class TestConfigSaver(unittest.TestCase):
tmp_config_saver = ConfigSaver("file-NOT-exist")
except Exception as e:
pass
def test_case_2(self):
config = "[section_A]\n[section_B]\n"
with open("./test.cfg", "w", encoding="utf-8") as f:
f.write(config)
saver = ConfigSaver("./test.cfg")
section = ConfigSection()
section["doubles"] = 0.8
section["tt"] = [1, 2, 3]
section["test"] = 105
section["str"] = "this is a str"
saver.save_config_file("section_A", section)
os.system("rm ./test.cfg")
def test_case_3(self):
config = "[section_A]\ndoubles = 0.9\ntt = [1, 2, 3]\n[section_B]\n"
with open("./test.cfg", "w", encoding="utf-8") as f:
f.write(config)
saver = ConfigSaver("./test.cfg")
section = ConfigSection()
section["doubles"] = 0.8
section["tt"] = [1, 2, 3]
section["test"] = 105
section["str"] = "this is a str"
saver.save_config_file("section_A", section)
os.system("rm ./test.cfg")