mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-11 10:05:30 +08:00
add tests
This commit is contained in:
parent
5133fe67b4
commit
325157b53f
@ -48,8 +48,6 @@ def simple_sort_bucketing(lengths):
|
||||
"""
|
||||
|
||||
:param lengths: list of int, the lengths of all examples.
|
||||
:param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length
|
||||
threshold for each bucket (This is usually None.).
|
||||
:return data: 2-level list
|
||||
::
|
||||
|
||||
@ -75,6 +73,7 @@ def k_means_1d(x, k, max_iter=100):
|
||||
assignment: numpy array, 1-D, the bucket id assigned to each example.
|
||||
"""
|
||||
sorted_x = sorted(list(set(x)))
|
||||
x = np.array(x)
|
||||
if len(sorted_x) < k:
|
||||
raise ValueError("too few buckets")
|
||||
gap = len(sorted_x) / k
|
||||
@ -119,34 +118,3 @@ def k_means_bucketing(lengths, buckets):
|
||||
bucket_data[bucket_id].append(idx)
|
||||
return bucket_data
|
||||
|
||||
|
||||
class BucketSampler(BaseSampler):
|
||||
"""Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
|
||||
In sampling, first random choose a bucket. Then sample data from it.
|
||||
The number of buckets is decided dynamically by the variance of sentence lengths.
|
||||
|
||||
"""
|
||||
|
||||
def __call__(self, data_set, batch_size, num_buckets):
|
||||
return self._process(data_set, batch_size, num_buckets)
|
||||
|
||||
def _process(self, data_set, batch_size, num_buckets, use_kmeans=False):
|
||||
"""
|
||||
|
||||
:param data_set: a DataSet object
|
||||
:param batch_size: int
|
||||
:param num_buckets: int, number of buckets for grouping these sequences.
|
||||
:param use_kmeans: bool, whether to use k-means to create buckets.
|
||||
|
||||
"""
|
||||
buckets = ([None] * num_buckets)
|
||||
if use_kmeans is True:
|
||||
buckets = k_means_bucketing(data_set, buckets)
|
||||
else:
|
||||
buckets = simple_sort_bucketing(data_set)
|
||||
index_list = []
|
||||
for _ in range(len(data_set) // batch_size):
|
||||
chosen_bucket = buckets[np.random.randint(0, len(buckets))]
|
||||
np.random.shuffle(chosen_bucket)
|
||||
index_list += [idx for idx in chosen_bucket[:batch_size]]
|
||||
return index_list
|
||||
|
@ -1,10 +1,10 @@
|
||||
import unittest
|
||||
|
||||
from fastNLP.core.field import CharTextField
|
||||
from fastNLP.core.field import CharTextField, LabelField, SeqLabelField
|
||||
|
||||
|
||||
class TestField(unittest.TestCase):
|
||||
def test_case(self):
|
||||
def test_char_field(self):
|
||||
text = "PhD applicants must submit a Research Plan and a resume " \
|
||||
"specify your class ranking written in English and a list of research" \
|
||||
" publications if any".split()
|
||||
@ -21,3 +21,22 @@ class TestField(unittest.TestCase):
|
||||
self.assertEqual(field.contents(), text)
|
||||
tensor = field.to_tensor(50)
|
||||
self.assertEqual(tuple(tensor.shape), (50, max_word_len))
|
||||
|
||||
def test_label_field(self):
|
||||
label = LabelField("A", is_target=True)
|
||||
self.assertEqual(label.get_length(), 1)
|
||||
self.assertEqual(label.index({"A": 10}), 10)
|
||||
|
||||
label = LabelField(30, is_target=True)
|
||||
self.assertEqual(label.get_length(), 1)
|
||||
tensor = label.to_tensor(0)
|
||||
self.assertEqual(tensor.shape, ())
|
||||
self.assertEqual(int(tensor), 30)
|
||||
|
||||
def test_seq_label_field(self):
|
||||
seq = ["a", "b", "c", "d", "a", "c", "a", "b"]
|
||||
field = SeqLabelField(seq)
|
||||
vocab = {"a": 10, "b": 20, "c": 30, "d": 40}
|
||||
self.assertEqual(field.index(vocab), [vocab[x] for x in seq])
|
||||
tensor = field.to_tensor(10)
|
||||
self.assertEqual(tuple(tensor.shape), (10,))
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
|
||||
from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler
|
||||
from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \
|
||||
k_means_1d, k_means_bucketing, simple_sort_bucketing
|
||||
|
||||
|
||||
def test_convert_to_torch_tensor():
|
||||
@ -26,5 +27,18 @@ def test_random_sampler():
|
||||
assert d in data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_sequential_sampler()
|
||||
def test_k_means():
|
||||
centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5)
|
||||
centroids, assign = list(centroids), list(assign)
|
||||
assert len(centroids) == 2
|
||||
assert len(assign) == 10
|
||||
|
||||
|
||||
def test_k_means_bucketing():
|
||||
res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None])
|
||||
assert len(res) == 2
|
||||
|
||||
|
||||
def test_simple_sort_bucketing():
|
||||
_ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10])
|
||||
assert len(_) == 10
|
||||
|
Loading…
Reference in New Issue
Block a user