"""Tests for the ``Unrepeated*`` samplers from ``fastNLP.core.samplers``.

Every test checks the same core property: the indices yielded by a sampler
(or by all replicas of a sampler together) cover ``range(num_of_data)``
exactly once. The sorted and sequential variants additionally check the
order in which indices are yielded.
"""
from itertools import chain

import pytest

from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler


class DatasetWithVaryLength:
    """Minimal in-memory dataset whose i-th sample is the integer ``i``.

    The tests also pass ``self.data`` as the ``length`` argument of the
    samplers, so the "length" of sample ``i`` is ``i`` itself.
    """

    def __init__(self, num_of_data=100):
        self.data = list(range(num_of_data))

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)


def _make_replica_samplers(sampler_cls, dataset, num_replicas, **sampler_kwargs):
    """Build one sampler per rank and put each into distributed mode.

    :param sampler_cls: sampler class to instantiate.
    :param dataset: dataset handed to every sampler as ``dataset=``.
    :param num_replicas: number of ranks to simulate.
    :param sampler_kwargs: extra keyword arguments for ``sampler_cls``.
    :return: list of samplers, where position ``i`` was configured with ``rank=i``.
    """
    samplers = []
    for rank in range(num_replicas):
        sampler = sampler_cls(dataset=dataset, **sampler_kwargs)
        sampler.set_distributed(num_replicas, rank=rank)
        samplers.append(sampler)
    return samplers


def _assert_exact_cover(indexes, num_of_data):
    """Assert that ``indexes`` contains every value of ``range(num_of_data)`` exactly once."""
    indexes = list(indexes)
    # Check the length first: converting to a set would silently hide duplicates.
    assert len(indexes) == num_of_data
    assert set(indexes) == set(range(num_of_data))


class TestUnrepeatedSampler:
    """Tests for ``UnrepeatedRandomSampler``."""

    @pytest.mark.parametrize('shuffle', [True, False])
    def test_single(self, shuffle):
        """A single sampler yields every index exactly once."""
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = UnrepeatedRandomSampler(data, shuffle)
        _assert_exact_cover(sampler, num_of_data)

    @pytest.mark.parametrize('num_replicas', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    @pytest.mark.parametrize('shuffle', [False, True])
    def test_multi(self, num_replicas, num_of_data, shuffle):
        """All replicas together yield every index exactly once."""
        if num_replicas > num_of_data:
            pytest.skip("num_replicas > num_of_data")
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = _make_replica_samplers(
            UnrepeatedRandomSampler, data, num_replicas, shuffle=shuffle
        )
        _assert_exact_cover(chain.from_iterable(samplers), num_of_data)


class TestUnrepeatedSortedSampler:
    """Tests for ``UnrepeatedSortedSampler``."""

    def test_single(self):
        """With ``length[i] == i`` the expected order is ``num_of_data - 1`` down to ``0``."""
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = UnrepeatedSortedSampler(data, length=data.data)
        indexes = list(sampler)
        assert indexes == list(range(num_of_data - 1, -1, -1))

    @pytest.mark.parametrize('num_replicas', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, num_replicas, num_of_data):
        """Each replica yields non-increasing indices; together they cover the dataset."""
        if num_replicas > num_of_data:
            pytest.skip("num_replicas > num_of_data")
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = _make_replica_samplers(
            UnrepeatedSortedSampler, data, num_replicas, length=data.data
        )
        # Materialise every replica once so that the order check and the
        # coverage check below look at the same sequence.
        per_rank = [list(sampler) for sampler in samplers]

        # Make sure the order within each replica has not been scrambled.
        for indexes in per_rank:
            assert all(prev >= cur for prev, cur in zip(indexes, indexes[1:]))

        # No overlap between different replicas (devices), and nothing missing.
        _assert_exact_cover(chain.from_iterable(per_rank), num_of_data)


class TestUnrepeatedSequentialSampler:
    """Tests for ``UnrepeatedSequentialSampler``."""

    def test_single(self):
        """A single sampler yields ``0 .. num_of_data - 1`` in order."""
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = UnrepeatedSequentialSampler(data, length=data.data)
        indexes = list(sampler)
        assert indexes == list(range(num_of_data))

    @pytest.mark.parametrize('num_replicas', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, num_replicas, num_of_data):
        """Each replica yields non-decreasing indices; together they cover the dataset."""
        if num_replicas > num_of_data:
            pytest.skip("num_replicas > num_of_data")
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = _make_replica_samplers(
            UnrepeatedSequentialSampler, data, num_replicas, length=data.data
        )
        # Materialise every replica once so that the order check and the
        # coverage check below look at the same sequence.
        per_rank = [list(sampler) for sampler in samplers]

        # Make sure the order within each replica has not been scrambled.
        for indexes in per_rank:
            assert all(prev <= cur for prev, cur in zip(indexes, indexes[1:]))

        # No overlap between different replicas (devices), and nothing missing.
        _assert_exact_cover(chain.from_iterable(per_rank), num_of_data)