Added tests for the reproducible batch sampler

YWMditto 2022-05-03 12:37:18 +08:00
parent 7d5ce620f4
commit 1528107480
5 changed files with 316 additions and 158 deletions
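
For readers skimming the diff: ReproduceBatchSampler wraps an iterable of indices, batches it, and can serialize its position, so an interrupted epoch resumes where it left off. A minimal sketch of the round trip the new tests exercise, assuming (as the tests below do with NormalSampler) that any iterable of indices can be wrapped:

from fastNLP.core.samplers import ReproduceBatchSampler

# consume three batches of four, then snapshot the sampler's state
sampler = ReproduceBatchSampler(list(range(100)), batch_size=4, drop_last=False)
it = iter(sampler)
for _ in range(3):
    next(it)
state = sampler.state_dict()  # holds index_list plus num_consumed_samples == 12

# a freshly built sampler picks up right after the 12th sample
restored = ReproduceBatchSampler(list(range(100)), batch_size=4, drop_last=False)
restored.load_state_dict(state)
assert next(iter(restored)) == list(range(12, 16))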

View File

@@ -8,7 +8,6 @@ import math
 from copy import deepcopy
 from typing import Dict, Union, List
 from itertools import chain
-import os
 import numpy as np

View File

@@ -1,7 +1,7 @@
 from functools import reduce
 from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader  # TODO: this class has changed; remember to update the test as well
-from tests.helpers.datasets.normal_data import NormalIterator
+from tests.helpers.datasets.normal_data import NormalSampler


 class Test_WrapDataLoader:
@@ -9,7 +9,7 @@ class Test_WrapDataLoader:
     def test_normal_generator(self):
         all_sanity_batches = [4, 20, 100]
         for sanity_batches in all_sanity_batches:
-            data = NormalIterator(num_of_data=1000)
+            data = NormalSampler(num_of_data=1000)
             wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
             dataloader = iter(wrapper(dataloader=data))
             mark = 0
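
For context, _TruncatedDataLoader caps iteration at a fixed number of batches for sanity checks. Its exact API is not shown in this diff (and the TODO above notes it has changed), but the underlying idea is just a capped iterator; an illustrative sketch only:

class TruncatedLoaderSketch:
    """Illustrative only: yield at most num_batches batches from a dataloader."""
    def __init__(self, dataloader, num_batches: int):
        self.dataloader = dataloader
        self.num_batches = num_batches

    def __iter__(self):
        for i, batch in enumerate(self.dataloader):
            if i >= self.num_batches:
                break
            yield batch

    def __len__(self):
        return min(self.num_batches, len(self.dataloader))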

View File

@@ -1,161 +1,131 @@
(The old body of this file, a commented-out torch-based TestReproducibleBatchSampler, is deleted here; it reappears, uncommented, as the new torch-only test file below. The rewritten framework-free tests that replace it:)

import numpy as np
import pytest
from itertools import chain
from copy import deepcopy
from array import array

from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler


class TestReproducibleBatchSampler:
    def test_1(self):
        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler makes no difference here
        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)

        forward_steps = 3
        iterator = iter(reproduce_batch_sampler)
        i = 0
        while i < forward_steps:
            next(iterator)
            i += 1

        # save the state
        state = reproduce_batch_sampler.state_dict()

        assert state == {"index_list": array("I", list(range(100))),
                         "num_consumed_samples": forward_steps * 4,
                         "sampler_type": "ReproduceBatchSampler"}

        # build a fresh batch sampler and load the state
        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler makes no difference here
        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
        reproduce_batch_sampler.load_state_dict(state)

        real_res = []
        supposed_res = (list(range(12, 16)), list(range(16, 20)))
        forward_steps = 2
        iter_dataloader = iter(reproduce_batch_sampler)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))

        for i in range(forward_steps):
            assert supposed_res[i] == real_res[i]

        # change the batch size
        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler makes no difference here
        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=7, drop_last=False)
        reproduce_batch_sampler.load_state_dict(state)

        real_res = []
        supposed_res = (list(range(12, 19)), list(range(19, 26)))
        forward_steps = 2
        iter_dataloader = iter(reproduce_batch_sampler)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))

        for i in range(forward_steps):
            assert supposed_res[i] == real_res[i]

        # check that the epoch after resumption covers the full dataset
        # first finish the epoch in which training was resumed
        begin_idx = 26
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert data == list(range(begin_idx, begin_idx + _batch_size))
                begin_idx += _batch_size
            except StopIteration:
                break

        # start a new epoch
        begin_idx = 0
        iter_dataloader = iter(reproduce_batch_sampler)
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert data == list(range(begin_idx, begin_idx + _batch_size))
                begin_idx += _batch_size
            except StopIteration:
                break

    def test_2(self):
        # test that a new epoch regenerates the index list instead of reusing the previous one
        before_batch_size = 7
        sampler = NormalSampler(num_of_data=100)
        # enable shuffle to check whether the second epoch after resumption regenerates its index list
        reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)

        # record all data from one epoch to verify that restoration is correct
        all_supposed_data = []
        forward_steps = 3
        iter_dataloader = iter(reproduce_batch_sampler)
        for _ in range(forward_steps):
            all_supposed_data.extend(next(iter_dataloader))

        # 1. save the state
        state = reproduce_batch_sampler.state_dict()

        # 2. resume from the checkpoint with a fresh dataloader
        # keep batch_size unchanged
        sampler = NormalSampler(num_of_data=100, shuffle=True)
        reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
        reproduce_batch_sampler.load_state_dict(state)

        # first consume the rest of the current epoch
        pre_index_list = reproduce_batch_sampler.state_dict()["index_list"]
        iter_dataloader = iter(reproduce_batch_sampler)
        while True:
            try:
                all_supposed_data.extend(next(iter_dataloader))
            except StopIteration:
                break
        assert all_supposed_data == list(pre_index_list)

        # start fresh epochs
        for _ in range(3):
            iter_dataloader = iter(reproduce_batch_sampler)
            res = []
            while True:
                try:
                    res.extend(next(iter_dataloader))
                except StopIteration:
                    break
            assert res != all_supposed_data


class DatasetWithVaryLength:
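
The expected batches in test_1 follow from simple arithmetic: the snapshot records 3 * 4 = 12 consumed samples, and resumption slices the saved index list from that offset with whatever batch size is loaded. A quick check of the numbers used above:

num_consumed = 3 * 4  # forward_steps * original batch_size
for bs in (4, 7):
    batches = [list(range(num_consumed + i * bs, num_consumed + (i + 1) * bs))
               for i in range(2)]
    print(bs, batches)
# 4 -> [12..15], [16..19]  (supposed_res with the batch size unchanged)
# 7 -> [12..18], [19..25]  (supposed_res after changing batch_size to 7)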

View File

@@ -0,0 +1,141 @@
from array import array

import torch
from torch.utils.data import DataLoader
import pytest

from fastNLP.core.samplers import ReproduceBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset


@pytest.mark.torch
class TestReproducibleBatchSamplerTorch:
    def test_torch_dataloader_1(self):
        # no shuffle
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            next(iter_dataloader)

        # 1. save the state
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
        state = _get_re_batchsampler.state_dict()
        assert state == {"index_list": array("I", list(range(100))),
                         "num_consumed_samples": forward_steps * before_batch_size,
                         "sampler_type": "ReproduceBatchSampler"}

        # 2. resume from the checkpoint with a fresh dataloader
        # keep batch_size unchanged
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        real_res = []
        supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))

        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])

        # change the batch size
        after_batch_size = 3
        dataloader = DataLoader(dataset, batch_size=after_batch_size)
        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        real_res = []
        supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))

        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])

        # check that the epoch after resumption covers the full dataset
        # first finish the epoch in which training was resumed
        begin_idx = 27
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break

        # start a new epoch
        begin_idx = 0
        iter_dataloader = iter(dataloader)
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break

    def test_torch_dataloader_2(self):
        # test that a new epoch regenerates the index list instead of reusing the previous one
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        # enable shuffle to check whether the second epoch after resumption regenerates its index list
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        # record all data from one epoch to verify that restoration is correct
        all_supposed_data = []
        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            all_supposed_data.extend(next(iter_dataloader).tolist())

        # 1. save the state
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
        state = _get_re_batchsampler.state_dict()

        # 2. resume from the checkpoint with a fresh dataloader
        # keep batch_size unchanged
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        iter_dataloader = iter(dataloader)

        # first consume the rest of the current epoch
        pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
        while True:
            try:
                all_supposed_data.extend(next(iter_dataloader).tolist())
            except StopIteration:
                break
        assert all_supposed_data == list(pre_index_list)

        # start fresh epochs
        for _ in range(3):
            iter_dataloader = iter(dataloader)
            res = []
            while True:
                try:
                    res.extend(next(iter_dataloader).tolist())
                except StopIteration:
                    break
            assert res != all_supposed_data
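
The torch tests lean on replace_batch_sampler to swap the reproducible sampler into an existing DataLoader. fastNLP's implementation is not part of this diff; roughly, such a helper rebuilds the DataLoader around the new batch sampler, along these lines (a sketch under that assumption, not the actual code):

from torch.utils.data import DataLoader

def replace_batch_sampler_sketch(dataloader: DataLoader, new_batch_sampler) -> DataLoader:
    # batch_size, shuffle, sampler and drop_last are owned by the batch
    # sampler, so they must not be passed alongside batch_sampler.
    return DataLoader(
        dataset=dataloader.dataset,
        batch_sampler=new_batch_sampler,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
        pin_memory=dataloader.pin_memory,
    )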

View File

@@ -1,13 +1,25 @@
 import numpy as np
+import random


-class NormalIterator:
-    def __init__(self, num_of_data=1000):
+class NormalSampler:
+    def __init__(self, num_of_data=1000, shuffle=False):
         self._num_of_data = num_of_data
         self._data = list(range(num_of_data))
+        if shuffle:
+            random.shuffle(self._data)
+        self.shuffle = shuffle
         self._index = 0
+        self.need_reinitialize = False

     def __iter__(self):
+        if self.need_reinitialize:
+            self._index = 0
+            if self.shuffle:
+                random.shuffle(self._data)
+        else:
+            self.need_reinitialize = True
+
         return self

     def __next__(self):
@@ -15,12 +27,45 @@ class NormalIterator:
             raise StopIteration
         _data = self._data[self._index]
         self._index += 1
-        return self._data
+        return _data

     def __len__(self):
         return self._num_of_data


+class NormalBatchSampler:
+    def __init__(self, sampler, batch_size: int, drop_last: bool) -> None:
+        # Since collections.abc.Iterable does not check for `__getitem__`, which
+        # is one way for an object to be an iterable, we don't do an `isinstance`
+        # check here.
+        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
+                batch_size <= 0:
+            raise ValueError("batch_size should be a positive integer value, "
+                             "but got batch_size={}".format(batch_size))
+        if not isinstance(drop_last, bool):
+            raise ValueError("drop_last should be a boolean value, but got "
+                             "drop_last={}".format(drop_last))
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+    def __iter__(self):
+        batch = []
+        for idx in self.sampler:
+            batch.append(idx)
+            if len(batch) == self.batch_size:
+                yield batch
+                batch = []
+        if len(batch) > 0 and not self.drop_last:
+            yield batch
+
+    def __len__(self) -> int:
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size
+        else:
+            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
+
+
 class RandomDataset:
     def __init__(self, num_data=10):
         self.data = np.random.rand(num_data)
@@ -29,4 +74,7 @@ class RandomDataset:
         return len(self.data)

     def __getitem__(self, item):
         return self.data[item]
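
A quick usage sketch of the two helpers added here; the expected values assume shuffle=False:

from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler

sampler = NormalSampler(num_of_data=10)
batch_sampler = NormalBatchSampler(sampler, batch_size=4, drop_last=False)

assert len(batch_sampler) == 3  # ceil(10 / 4), since drop_last=False
assert list(batch_sampler) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# a second pass works because NormalSampler re-initializes itself on __iter__
assert list(batch_sampler) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]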