Add unit tests for the data package; fix a Vocabulary bug

This commit is contained in:
yunfan 2018-09-15 20:39:51 +08:00
parent 466f3c21ec
commit 3f4544759d
3 changed files with 108 additions and 18 deletions

View File

@@ -36,9 +36,10 @@ class Vocabulary(object):
self.update(w)
else:
# it's a word to be added
self.word2idx[word] = len(self)
if self.idx2word is not None:
self.idx2word = None
if word not in self.word2idx:
self.word2idx[word] = len(self)
if self.idx2word is not None:
self.idx2word = None
def __getitem__(self, w):
@@ -80,20 +81,5 @@ class Vocabulary(object):
self.__dict__.update(state)
self.idx2word = None
if __name__ == '__main__':
    # Manual smoke test: build a small vocabulary, pickle it out and back in,
    # and print the same lookup before and after the round-trip.
    import _pickle as pickle

    v = Vocabulary()
    filename = 'vocab'
    v.update(filename)
    v.update([filename, ['a'], [['b']], ['c']])
    word_idx = v[filename]
    print('{} {}'.format(v.to_word(word_idx), v[filename]))
    with open(filename, 'wb') as dump_file:
        pickle.dump(v, dump_file)
    with open(filename, 'rb') as load_file:
        v = pickle.load(load_file)
    print('{} {}'.format(v.to_word(word_idx), v[filename]))
    print(v.word2idx)

69
test/core/test_field.py Normal file
View File

@@ -0,0 +1,69 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
import unittest
import torch
from fastNLP.data.field import TextField, LabelField
from fastNLP.data.instance import Instance
from fastNLP.data.dataset import DataSet
from fastNLP.data.batch import Batch
class TestField(unittest.TestCase):
def check_batched_data_equal(self, data1, data2):
self.assertEqual(len(data1), len(data2))
for i in range(len(data1)):
self.assertTrue(data1[i].keys(), data2[i].keys())
for i in range(len(data1)):
for t1, t2 in zip(data1[i].values(), data2[i].values()):
self.assertTrue(torch.equal(t1, t2))
def test_batchiter(self):
texts = [
"i am a cat",
"this is a test of new batch",
"haha"
]
labels = [0, 1, 0]
# prepare vocabulary
vocab = {}
for text in texts:
for tokens in text.split():
if tokens not in vocab:
vocab[tokens] = len(vocab)
# prepare input dataset
data = DataSet()
for text, label in zip(texts, labels):
x = TextField(text.split(), False)
y = LabelField(label, is_target=True)
ins = Instance(text=x, label=y)
data.append(ins)
# use vocabulary to index data
data.index_field("text", vocab)
# define naive sampler for batch class
class SeqSampler:
def __call__(self, dataset):
return list(range(len(dataset)))
# use bacth to iterate dataset
batcher = Batch(data, SeqSampler(), 2)
TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}]
TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}]
for epoch in range(3):
test_x, test_y = [], []
for batch_x, batch_y in batcher:
test_x.append(batch_x)
test_y.append(batch_y)
self.check_batched_data_equal(TRUE_X, test_x)
self.check_batched_data_equal(TRUE_Y, test_y)
if __name__ == "__main__":
    # Allow running this test file directly, in addition to via a test runner.
    unittest.main()

35
test/core/test_vocab.py Normal file
View File

@@ -0,0 +1,35 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
import unittest
from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX
class TestVocabulary(unittest.TestCase):
    """Content and pickle round-trip tests for Vocabulary."""

    def test_vocab(self):
        import _pickle as pickle
        import os
        import tempfile

        vocab = Vocabulary()
        word = 'vocab'
        vocab.update(word)
        # update() must flatten arbitrarily nested lists of words.
        vocab.update([word, ['a'], [['b']], ['c']])
        idx = vocab[word]
        before_pic = (vocab.to_word(idx), vocab[word])

        # Robustness fix: the original dumped the pickle to a file named
        # 'vocab' in the CWD and removed it manually, leaking the file if
        # any step failed first. A temporary directory cleans up always.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, 'vocab.pkl')
            with open(filename, 'wb') as f:
                pickle.dump(vocab, f)
            with open(filename, 'rb') as f:
                vocab = pickle.load(f)

        vocab.build_reverse_vocab()
        after_pic = (vocab.to_word(idx), vocab[word])

        # User words start at index 5, after the reserved defaults.
        TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8}
        TRUE_DICT.update(DEFAULT_WORD_TO_INDEX)
        TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>',
                        4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'}
        self.assertEqual(before_pic, after_pic)
        self.assertDictEqual(TRUE_DICT, vocab.word2idx)
        self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word)
if __name__ == '__main__':
    # Allow running this test file directly, in addition to via a test runner.
    unittest.main()