# fastNLP/tests/core/samplers/test_unrepeated_sampler.py
# (file-viewer metadata: 111 lines, 3.9 KiB, Python)

from itertools import chain
import pytest
from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
class DatasetWithVaryLength:
    """Tiny in-memory dataset whose i-th item is the integer ``i``.

    The backing list is exposed as ``data``; the sampler tests in this file
    also pass it as the per-sample ``length`` argument.
    """

    def __init__(self, num_of_data=100):
        """Populate ``data`` with the integers ``0 .. num_of_data - 1``."""
        self.data = [*range(num_of_data)]

    def __len__(self):
        """Return the number of stored items."""
        return len(self.data)

    def __getitem__(self, item):
        """Return the stored value(s) at ``item`` (index or slice)."""
        return self.data[item]
class TestUnrepeatedSampler:
    """Tests for ``UnrepeatedRandomSampler`` alone and split across replicas."""

    @pytest.mark.parametrize('shuffle', [True, False])
    def test_single(self, shuffle):
        """A non-distributed sampler yields every index exactly once."""
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = UnrepeatedRandomSampler(data, shuffle)
        indexes = list(sampler)
        # Check the count before collapsing to a set: a set comparison alone
        # would let a repeated index slip through unnoticed.
        assert len(indexes) == num_of_data
        assert set(indexes) == set(range(num_of_data))

    @pytest.mark.parametrize('num_replicas', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    @pytest.mark.parametrize('shuffle', [False, True])
    def test_multi(self, num_replicas, num_of_data, shuffle):
        """All ranks together cover every index exactly once."""
        if num_replicas > num_of_data:
            pytest.skip("num_replicas > num_of_data")
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        # One sampler per simulated rank, all over the same dataset.
        samplers = []
        for i in range(num_replicas):
            sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle)
            sampler.set_distributed(num_replicas, rank=i)
            samplers.append(sampler)
        indexes = list(chain(*samplers))
        # Equal count plus equal set means no index is duplicated or missing.
        assert len(indexes) == num_of_data
        assert set(indexes) == set(range(num_of_data))
class TestUnrepeatedSortedSampler:
    """Tests for ``UnrepeatedSortedSampler`` alone and split across replicas."""

    def test_single(self):
        """Without replicas the indices come out from the highest to the lowest."""
        total = 100
        dataset = DatasetWithVaryLength(total)
        # Sample i is given length i, so the expected order is total-1, ..., 0.
        expected = list(reversed(range(total)))
        drawn = list(UnrepeatedSortedSampler(dataset, length=dataset.data))
        assert drawn == expected

    @pytest.mark.parametrize('num_replicas', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, num_replicas, num_of_data):
        """Each rank stays in non-increasing order and the ranks partition the data."""
        if num_of_data < num_replicas:
            pytest.skip("num_replicas > num_of_data")
        dataset = DatasetWithVaryLength(num_of_data=num_of_data)

        # One sampler per simulated rank, all over the same dataset.
        per_rank = []
        for rank in range(num_replicas):
            replica = UnrepeatedSortedSampler(dataset=dataset, length=dataset.data)
            replica.set_distributed(num_replicas, rank=rank)
            per_rank.append(replica)

        # Make sure the order has not been scrambled within any rank.
        for replica in per_rank:
            drawn = list(replica)
            assert all(later <= earlier for earlier, later in zip(drawn, drawn[1:]))

        merged = list(chain.from_iterable(per_rank))
        assert len(merged) == num_of_data  # no overlap between different ranks
        assert set(merged) == set(range(num_of_data))
class TestUnrepeatedSequentialSampler:
    """Tests for ``UnrepeatedSequentialSampler`` alone and split across replicas."""

    # NOTE(review): every construction below passes ``length=``, which looks
    # copied from the sorted-sampler tests -- confirm the sequential sampler
    # actually uses it.

    def test_single(self):
        """Without replicas the indices come out as 0, 1, ..., n - 1."""
        total = 100
        dataset = DatasetWithVaryLength(total)
        expected = [*range(total)]
        drawn = list(UnrepeatedSequentialSampler(dataset, length=dataset.data))
        assert drawn == expected

    @pytest.mark.parametrize('num_replicas', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, num_replicas, num_of_data):
        """Each rank stays in non-decreasing order and the ranks partition the data."""
        if num_of_data < num_replicas:
            pytest.skip("num_replicas > num_of_data")
        dataset = DatasetWithVaryLength(num_of_data=num_of_data)

        # One sampler per simulated rank, all over the same dataset.
        per_rank = []
        for rank in range(num_replicas):
            replica = UnrepeatedSequentialSampler(dataset=dataset, length=dataset.data)
            replica.set_distributed(num_replicas, rank=rank)
            per_rank.append(replica)

        # Make sure the order has not been scrambled within any rank.
        for replica in per_rank:
            drawn = list(replica)
            assert all(later >= earlier for earlier, later in zip(drawn, drawn[1:]))

        merged = list(chain.from_iterable(per_rank))
        assert len(merged) == num_of_data
        assert set(merged) == set(range(num_of_data))