# Mirror of https://gitee.com/fastnlp/fastNLP.git
# (synced 2024-12-05 05:38:31 +08:00; 111 lines, 3.9 KiB, Python)
from itertools import chain
|
|
|
|
import pytest
|
|
|
|
from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
|
|
|
|
|
|
class DatasetWithVaryLength:
    """Minimal map-style dataset whose sample at position ``i`` is the integer ``i``.

    The sorted-sampler tests in this module pass ``data`` as the ``length``
    argument, so sample ``i`` is effectively treated as having length ``i``.
    """

    def __init__(self, num_of_data=100):
        # Values double as their own "lengths" when handed to a length-aware sampler.
        self.data = [*range(num_of_data)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        # Delegates straight to the list, so slices and negative indexes also work.
        return self.data[item]
|
|
|
|
|
|
class TestUnrepeatedSampler:
    """Tests for ``UnrepeatedRandomSampler`` in single- and multi-replica setups."""

    @pytest.mark.parametrize('shuffle', [True, False])
    def test_single(self, shuffle):
        """A single sampler must yield every index of the dataset exactly once."""
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = UnrepeatedRandomSampler(data, shuffle=shuffle)
        indexes = list(sampler)
        # Check the count before collapsing to a set: a set on its own would
        # hide duplicated indexes, which an "unrepeated" sampler must never emit.
        assert len(indexes) == num_of_data
        assert set(indexes) == set(range(num_of_data))

    @pytest.mark.parametrize('num_replicas', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    @pytest.mark.parametrize('shuffle', [False, True])
    def test_multi(self, num_replicas, num_of_data, shuffle):
        """Samplers on all ranks together must partition the dataset's indexes."""
        if num_replicas > num_of_data:
            pytest.skip("num_replicas > num_of_data")
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for i in range(num_replicas):
            sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle)
            sampler.set_distributed(num_replicas, rank=i)
            samplers.append(sampler)

        indexes = list(chain(*samplers))
        # Total count equal to the dataset size, plus set equality, means no
        # index is shared between ranks and none is missing.
        assert len(indexes) == num_of_data
        assert set(indexes) == set(range(num_of_data))
|
|
|
|
|
|
class TestUnrepeatedSortedSampler:
    """Tests for ``UnrepeatedSortedSampler``, which orders indexes by ``length``."""

    def test_single(self):
        """With lengths equal to the indexes, a lone sampler yields them in descending order."""
        total = 100
        dataset = DatasetWithVaryLength(total)
        expected = list(reversed(range(total)))
        emitted = list(UnrepeatedSortedSampler(dataset, length=dataset.data))
        assert emitted == expected

    @pytest.mark.parametrize('num_replicas', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, num_replicas, num_of_data):
        """Each rank yields a non-increasing run; all ranks together cover every index once."""
        if num_replicas > num_of_data:
            pytest.skip("num_replicas > num_of_data")
        dataset = DatasetWithVaryLength(num_of_data=num_of_data)

        def build(rank):
            # One sampler per simulated rank, all sharing the same dataset.
            per_rank = UnrepeatedSortedSampler(dataset=dataset, length=dataset.data)
            per_rank.set_distributed(num_replicas, rank=rank)
            return per_rank

        samplers = [build(rank) for rank in range(num_replicas)]

        # Make sure each rank keeps the sorted (descending) order. Since every
        # sample's length equals its index here, comparing indexes compares lengths.
        for per_rank in samplers:
            run = list(per_rank)
            assert all(later <= earlier for earlier, later in zip(run, run[1:]))

        merged = list(chain.from_iterable(samplers))
        # No overlap between ranks (count matches) and nothing dropped (set matches).
        assert len(merged) == num_of_data
        assert set(merged) == set(range(num_of_data))
|
|
|
|
|
|
class TestUnrepeatedSequentialSampler:
    """Tests for ``UnrepeatedSequentialSampler``, which preserves dataset order."""

    def test_single(self):
        """A lone sequential sampler yields 0..n-1 in order."""
        total = 100
        dataset = DatasetWithVaryLength(total)
        # NOTE(review): ``length`` mirrors the sorted-sampler test; it is not
        # evident that a sequential sampler uses it — confirm and drop if unused.
        sampler = UnrepeatedSequentialSampler(dataset, length=dataset.data)
        assert [*sampler] == [*range(total)]

    @pytest.mark.parametrize('num_replicas', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, num_replicas, num_of_data):
        """Each rank yields a non-decreasing run; all ranks together cover every index once."""
        if num_replicas > num_of_data:
            pytest.skip("num_replicas > num_of_data")
        dataset = DatasetWithVaryLength(num_of_data=num_of_data)

        def build(rank):
            # One sampler per simulated rank, all sharing the same dataset.
            per_rank = UnrepeatedSequentialSampler(dataset=dataset, length=dataset.data)
            per_rank.set_distributed(num_replicas, rank=rank)
            return per_rank

        samplers = [build(rank) for rank in range(num_replicas)]

        # Make sure each rank keeps ascending dataset order.
        for per_rank in samplers:
            run = list(per_rank)
            assert all(earlier <= later for earlier, later in zip(run, run[1:]))

        merged = list(chain.from_iterable(samplers))
        # No overlap between ranks (count matches) and nothing dropped (set matches).
        assert len(merged) == num_of_data
        assert set(merged) == set(range(num_of_data))
|