From 3f4544759ddfd4569be034de811b366f1a6bb3cf Mon Sep 17 00:00:00 2001 From: yunfan Date: Sat, 15 Sep 2018 20:39:51 +0800 Subject: [PATCH] add unittest of data, fix bug --- fastNLP/{data => core}/vocabulary.py | 22 ++------- test/core/test_field.py | 69 ++++++++++++++++++++++++++++ test/core/test_vocab.py | 35 ++++++++++++++ 3 files changed, 108 insertions(+), 18 deletions(-) rename fastNLP/{data => core}/vocabulary.py (79%) create mode 100644 test/core/test_field.py create mode 100644 test/core/test_vocab.py diff --git a/fastNLP/data/vocabulary.py b/fastNLP/core/vocabulary.py similarity index 79% rename from fastNLP/data/vocabulary.py rename to fastNLP/core/vocabulary.py index 3cff161b..baae3753 100644 --- a/fastNLP/data/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -36,9 +36,10 @@ class Vocabulary(object): self.update(w) else: # it's a word to be added - self.word2idx[word] = len(self) - if self.idx2word is not None: - self.idx2word = None + if word not in self.word2idx: + self.word2idx[word] = len(self) + if self.idx2word is not None: + self.idx2word = None def __getitem__(self, w): @@ -80,20 +81,5 @@ class Vocabulary(object): self.__dict__.update(state) self.idx2word = None -if __name__ == '__main__': - import _pickle as pickle - vocab = Vocabulary() - filename = 'vocab' - vocab.update(filename) - vocab.update([filename, ['a'], [['b']], ['c']]) - idx = vocab[filename] - print('{} {}'.format(vocab.to_word(idx), vocab[filename])) - with open(filename, 'wb') as f: - pickle.dump(vocab, f) - with open(filename, 'rb') as f: - vocab = pickle.load(f) - - print('{} {}'.format(vocab.to_word(idx), vocab[filename])) - print(vocab.word2idx) \ No newline at end of file diff --git a/test/core/test_field.py b/test/core/test_field.py new file mode 100644 index 00000000..7c1b6343 --- /dev/null +++ b/test/core/test_field.py @@ -0,0 +1,69 @@ +import os +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) + +import unittest +import torch +from fastNLP.data.field import TextField, LabelField +from fastNLP.data.instance import Instance +from fastNLP.data.dataset import DataSet +from fastNLP.data.batch import Batch + + + +class TestField(unittest.TestCase): + def check_batched_data_equal(self, data1, data2): + self.assertEqual(len(data1), len(data2)) + for i in range(len(data1)): + self.assertTrue(data1[i].keys(), data2[i].keys()) + for i in range(len(data1)): + for t1, t2 in zip(data1[i].values(), data2[i].values()): + self.assertTrue(torch.equal(t1, t2)) + + def test_batchiter(self): + texts = [ + "i am a cat", + "this is a test of new batch", + "haha" + ] + labels = [0, 1, 0] + + # prepare vocabulary + vocab = {} + for text in texts: + for tokens in text.split(): + if tokens not in vocab: + vocab[tokens] = len(vocab) + + # prepare input dataset + data = DataSet() + for text, label in zip(texts, labels): + x = TextField(text.split(), False) + y = LabelField(label, is_target=True) + ins = Instance(text=x, label=y) + data.append(ins) + + # use vocabulary to index data + data.index_field("text", vocab) + + # define naive sampler for batch class + class SeqSampler: + def __call__(self, dataset): + return list(range(len(dataset))) + + # use bacth to iterate dataset + batcher = Batch(data, SeqSampler(), 2) + TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}] + TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}] + for epoch in range(3): + test_x, test_y = [], [] + for batch_x, batch_y in batcher: + test_x.append(batch_x) + test_y.append(batch_y) + self.check_batched_data_equal(TRUE_X, test_x) + self.check_batched_data_equal(TRUE_Y, test_y) + + +if __name__ == "__main__": + unittest.main() + \ No newline at end of file diff --git a/test/core/test_vocab.py b/test/core/test_vocab.py new file mode 100644 index 00000000..dd51c197 --- /dev/null +++ b/test/core/test_vocab.py @@ -0,0 +1,35 @@ +import os +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) + +import unittest +from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX + +class TestVocabulary(unittest.TestCase): + def test_vocab(self): + import _pickle as pickle + import os + vocab = Vocabulary() + filename = 'vocab' + vocab.update(filename) + vocab.update([filename, ['a'], [['b']], ['c']]) + idx = vocab[filename] + before_pic = (vocab.to_word(idx), vocab[filename]) + + with open(filename, 'wb') as f: + pickle.dump(vocab, f) + with open(filename, 'rb') as f: + vocab = pickle.load(f) + os.remove(filename) + + vocab.build_reverse_vocab() + after_pic = (vocab.to_word(idx), vocab[filename]) + TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8} + TRUE_DICT.update(DEFAULT_WORD_TO_INDEX) + TRUE_IDXDICT = {0: '', 1: '', 2: '', 3: '', 4: '', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'} + self.assertEqual(before_pic, after_pic) + self.assertDictEqual(TRUE_DICT, vocab.word2idx) + self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file