mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-03 04:37:37 +08:00)
add unittest of data, fix bug
This commit is contained in:
parent 466f3c21ec
commit 3f4544759d
fastNLP/data/vocabulary.py

@@ -36,9 +36,10 @@ class Vocabulary(object):
             self.update(w)
         else:
             # it's a word to be added
-            self.word2idx[word] = len(self)
-            if self.idx2word is not None:
-                self.idx2word = None
+            if word not in self.word2idx:
+                self.word2idx[word] = len(self)
+                if self.idx2word is not None:
+                    self.idx2word = None

     def __getitem__(self, w):
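Note on the hunk above: before the fix, the assignment self.word2idx[word] = len(self) ran unconditionally, so updating with a word the vocabulary already contained re-mapped it to a fresh index and left a hole at its old one. The added guard makes repeated updates idempotent. A minimal standalone sketch of the fixed path (an illustration only; it assumes, as the hunk suggests, that word2idx is a plain dict and idx2word a lazily rebuilt reverse map):

    word2idx = {'<pad>': 0, '<unk>': 1}
    idx2word = None  # cached reverse map; None means "stale, rebuild on demand"

    def add_word(word):
        """Assign the next free index to word, skipping words already present."""
        global idx2word
        if word not in word2idx:            # the guard added by this commit
            word2idx[word] = len(word2idx)  # next free index
            idx2word = None                 # adding a word invalidates the cache

    add_word('cat')
    add_word('cat')                         # now a no-op instead of re-indexing 'cat'
    assert word2idx['cat'] == 2 and len(word2idx) == 3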
@@ -80,20 +81,5 @@ class Vocabulary(object):
         self.__dict__.update(state)
         self.idx2word = None

-
-if __name__ == '__main__':
-    import _pickle as pickle
-    vocab = Vocabulary()
-    filename = 'vocab'
-    vocab.update(filename)
-    vocab.update([filename, ['a'], [['b']], ['c']])
-    idx = vocab[filename]
-    print('{} {}'.format(vocab.to_word(idx), vocab[filename]))
-
-    with open(filename, 'wb') as f:
-        pickle.dump(vocab, f)
-    with open(filename, 'rb') as f:
-        vocab = pickle.load(f)
-
-    print('{} {}'.format(vocab.to_word(idx), vocab[filename]))
-    print(vocab.word2idx)
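The __main__ demo removed here is superseded by the new test/core/test_vocab.py below. For context, a minimal round-trip sketch of the __setstate__ behavior this hunk keeps (illustrative only; the 'vocab.pkl' filename is made up, and it assumes build_reverse_vocab() repopulates idx2word, as the new test does):

    import _pickle as pickle
    from fastNLP.data.vocabulary import Vocabulary

    vocab = Vocabulary()
    vocab.update(['hello', 'world'])

    with open('vocab.pkl', 'wb') as f:
        pickle.dump(vocab, f)
    with open('vocab.pkl', 'rb') as f:
        loaded = pickle.load(f)

    # __setstate__ resets the cached reverse map on unpickling ...
    assert loaded.idx2word is None
    # ... so it must be rebuilt before index-to-word lookups.
    loaded.build_reverse_vocab()
    assert loaded.to_word(loaded['hello']) == 'hello'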
test/core/test_field.py  Normal file  (69 lines)

@@ -0,0 +1,69 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

import unittest
import torch
from fastNLP.data.field import TextField, LabelField
from fastNLP.data.instance import Instance
from fastNLP.data.dataset import DataSet
from fastNLP.data.batch import Batch


class TestField(unittest.TestCase):
    def check_batched_data_equal(self, data1, data2):
        self.assertEqual(len(data1), len(data2))
        for i in range(len(data1)):
            # corresponding batches must expose the same field names
            self.assertEqual(data1[i].keys(), data2[i].keys())
            # and identical tensors for every field
            for t1, t2 in zip(data1[i].values(), data2[i].values()):
                self.assertTrue(torch.equal(t1, t2))

    def test_batchiter(self):
        texts = [
            "i am a cat",
            "this is a test of new batch",
            "haha"
        ]
        labels = [0, 1, 0]

        # prepare the vocabulary: map each distinct token to an index
        vocab = {}
        for text in texts:
            for token in text.split():
                if token not in vocab:
                    vocab[token] = len(vocab)

        # prepare the input dataset
        data = DataSet()
        for text, label in zip(texts, labels):
            x = TextField(text.split(), False)
            y = LabelField(label, is_target=True)
            ins = Instance(text=x, label=y)
            data.append(ins)

        # use the vocabulary to index the text field
        data.index_field("text", vocab)

        # define a naive sequential sampler for the Batch class
        class SeqSampler:
            def __call__(self, dataset):
                return list(range(len(dataset)))

        # use Batch to iterate over the dataset, two instances at a time
        batcher = Batch(data, SeqSampler(), 2)
        TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}]
        TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}]
        for epoch in range(3):
            test_x, test_y = [], []
            for batch_x, batch_y in batcher:
                test_x.append(batch_x)
                test_y.append(batch_y)
            self.check_batched_data_equal(TRUE_X, test_x)
            self.check_batched_data_equal(TRUE_Y, test_y)


if __name__ == "__main__":
    unittest.main()
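A side note on the expected tensors in TRUE_X: they imply that Batch zero-pads every text field to the length of the longest sequence in its batch (an inference from the fixture, not documented behavior). The first batch above, reproduced by hand:

    import torch

    # indexed tokens of "i am a cat" and "this is a test of new batch"
    seqs = [[0, 1, 2, 3], [4, 5, 2, 6, 7, 8, 9]]
    max_len = max(len(s) for s in seqs)          # 7
    padded = torch.tensor([s + [0] * (max_len - len(s)) for s in seqs])
    assert torch.equal(padded[0], torch.tensor([0, 1, 2, 3, 0, 0, 0]))
    assert padded.shape == (2, 7)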
test/core/test_vocab.py  Normal file  (35 lines)

@@ -0,0 +1,35 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

import _pickle as pickle
import unittest

from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX


class TestVocabulary(unittest.TestCase):
    def test_vocab(self):
        vocab = Vocabulary()
        filename = 'vocab'
        vocab.update(filename)
        vocab.update([filename, ['a'], [['b']], ['c']])
        idx = vocab[filename]
        before_pic = (vocab.to_word(idx), vocab[filename])

        # pickle round-trip, then clean up the temporary file
        with open(filename, 'wb') as f:
            pickle.dump(vocab, f)
        with open(filename, 'rb') as f:
            vocab = pickle.load(f)
        os.remove(filename)

        # __setstate__ drops idx2word, so rebuild it before to_word()
        vocab.build_reverse_vocab()
        after_pic = (vocab.to_word(idx), vocab[filename])
        TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8}
        TRUE_DICT.update(DEFAULT_WORD_TO_INDEX)
        TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'}
        self.assertEqual(before_pic, after_pic)
        self.assertDictEqual(TRUE_DICT, vocab.word2idx)
        self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word)


if __name__ == '__main__':
    unittest.main()
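A side note on the fixtures: TRUE_IDXDICT shows that DEFAULT_WORD_TO_INDEX reserves indices 0-4 ('<pad>', '<unk>' and three reserved slots), so user words are numbered from 5 in insertion order, and update() recurses into nested lists (which is why 'b' inside [['b']] still gets an index). A sketch of that numbering (illustrative; only the reserved entries come from the test fixture):

    DEFAULT = {'<pad>': 0, '<unk>': 1, '<reserved-2>': 2,
               '<reserved-3>': 3, '<reserved-4>': 4}
    word2idx = dict(DEFAULT)
    for w in ['vocab', 'a', 'b', 'c']:          # flattened update order
        word2idx.setdefault(w, len(word2idx))   # next free index, no re-indexing
    assert word2idx == {**DEFAULT, 'vocab': 5, 'a': 6, 'b': 7, 'c': 8}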