添加了collators测试用例

This commit is contained in:
MorningForest 2022-04-08 21:35:39 +08:00
parent 41b0bae46b
commit 65c621db78

View File

@ -0,0 +1,81 @@
import pytest
from fastNLP.core.collators import AutoCollator
from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.dataset import DataSet
class TestCollator:
@pytest.mark.parametrize('as_numpy', [True, False])
def test_auto_collator(self, as_numpy):
"""
测试auto_collator的auto_pad功能
:param as_numpy:
:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=as_numpy)
collator.set_input('x', 'y')
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
results = []
for bucket in bucket_data:
res = collator(bucket)
assert res['x'].shape == (40, 5)
assert res['y'].shape == (40,)
results.append(res)
def test_auto_collator_v1(self):
"""
测试auto_collator的set_pad_val和set_pad_val功能
:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=False)
collator.set_input('x')
collator.set_pad_val('x', val=-1)
collator.set_as_numpy(True)
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
for bucket in bucket_data:
res = collator(bucket)
print(res)
def test_multicollator(self):
"""
测试multicollator功能
:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=False)
multi_collator = _MultiCollator(collator)
multi_collator.set_as_numpy(as_numpy=True)
multi_collator.set_pad_val('x', val=-1)
multi_collator.set_input('x')
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
for bucket in bucket_data:
res = multi_collator(bucket)
print(res)