Add RandomBatchSampler

This commit is contained in:
yh_cc 2022-05-02 23:08:50 +08:00
parent 4e1b74c4cb
commit 7d5ce620f4
25 changed files with 692 additions and 181 deletions

View File

@@ -65,12 +65,16 @@ def _get_backend() -> str:
return catch_backend[0]
# approach (2)
for backend in CHECK_BACKEND:
if backend in sys.modules:
logger.debug(f"sys.modules contains backend:{catch_backend[0]}.")
return backend
for key, module in sys.modules.items():
catch_backend = _check_module(module)
if catch_backend:
break
if len(catch_backend):
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
logger.debug(f"Find a module file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
return catch_backend[0]
return 'numpy'
@@ -227,7 +231,7 @@ class Collator:
Set what type of tensor paddable fields are padded into by default.
:param backend: which kind of tensor to use for paddable fields; supports ['torch','jittor','paddle','numpy','raw', 'auto', None].
If 'auto', the backend is decided by the calling environment when padding is performed.
If 'auto', the backend is automatically decided by the calling environment when padding is performed.
:return:
"""
assert backend in SUPPORTED_BACKENDS
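For orientation, a hedged sketch of what the 'auto' option buys at the call site; the toy instances and the assumption that a Collator is directly callable on a list of instances are illustrative, not part of this diff:

# Assumed usage sketch for backend='auto': the backend is resolved lazily by
# _get_backend() above, from the caller's module and sys.modules.
from fastNLP.core.collators import Collator

collator = Collator(backend='auto')
batch = collator([{'x': [1, 2]}, {'x': [3]}])
# 'x' is padded and converted by whichever backend the environment implies,
# falling back to 'numpy' when no framework can be detected.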

View File

@@ -74,7 +74,7 @@ def _get_dtype(ele_dtype, dtype, class_name):
elif is_numpy_generic_class(ele_dtype):
dtype = numpy_to_paddle_dtype_dict.get(ele_dtype)
else:
dtype == ele_dtype
dtype = ele_dtype
return dtype
@@ -174,6 +174,5 @@ def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0):
"""
shapes = get_shape(batch_field)
tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype)
# tensor = paddle.full(shape=shapes, dtype=dtype, fill_value=pad_val)
tensor = fill_tensor(batch_field, tensor, dtype=dtype)
return tensor
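A hedged usage sketch of the padder above; the import path and the toy ragged field are assumptions:

# Assumed illustration: a ragged int field becomes a rectangular paddle tensor,
# with missing positions filled by pad_val.
from fastNLP.core.collators.padders.paddle_padder import get_padded_paddle_tensor  # assumed path

tensor = get_padded_paddle_tensor([[1, 2, 3], [4, 5]], dtype='int64', pad_val=0)
print(tensor.tolist())  # [[1, 2, 3], [4, 5, 0]]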

View File

@@ -363,7 +363,6 @@ class Trainer(TrainerEventTrigger):
raise e
finally:
self.on_train_end()
self.driver.barrier()
def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None:
@@ -441,6 +440,7 @@ class Trainer(TrainerEventTrigger):
"""
_own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"])
_own_callbacks.extend(self._custom_callbacks[None])
logger.debug(f"Get {len(_own_callbacks)} callback fns through Trainer.on().")
self._custom_callbacks[None] = []
if self.marker is not None:
if len(self._custom_callbacks[self.marker]) == 0:

View File

@@ -14,7 +14,7 @@ else:
from fastNLP.core.dataset import DataSet as Dataset
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
from fastNLP.core.collators import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet
@@ -106,33 +106,33 @@ class JittorDataLoader:
return len(self.dataset) // self.dataset.batch_size
return (len(self.dataset) - 1) // self.dataset.batch_size + 1
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> "JittorDataLoader":
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
"""
Use this function if the content of a certain field needs special adjustment
Use this function if the content of a certain field needs special adjustment.
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ method returns a dict, the field's
key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. use ('a', 'b') for the
inner key of {'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' refer to the 0th and 1st elements of
the sequence. An error is raised if the field is not found in the data; if __getitem__ returns the content as a
whole, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field will not be padded. fastNLP only
pads fields that can be padded, so if the field itself is not in a paddable form, it is not necessary to set this
to None.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing output of type list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var respectively. When pad_val is None, this may only be None or 'numpy'.
:param pad_fn: a custom pad function for this field. If given, the pad_val, dtype and backend arguments are ignored.
pad_fn receives the batch form of this field: the Collator automatically unbatches the data and regroups each
field into its own batch, and that field batch is the input of pad_fn; its output is used directly as the result.
:return: the Collator itself
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ method returns a dict, the field's
key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. use ('a', 'b') for the
inner key of {'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' refer to the 0th and 1st elements of
the sequence. An error is raised if the field is not found in the data; if __getitem__ returns the content as a
whole, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field will not be padded. fastNLP only
pads fields that can be padded, so if the field itself is not in a paddable form, it is not necessary to set this
to None. When backend is None, this value is meaningless.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing output of type list,
numpy.ndarray, torch.Tensor, paddle.Tensor or jittor.Var respectively. When pad_val is None, this value is
meaningless.
:param pad_fn: a custom pad function for this field. If given, the pad_val, dtype and backend arguments are ignored.
pad_fn receives the batch form of this field: the Collator automatically unbatches the data and regroups each
field into its own batch, and that field batch is the input of pad_fn; its output is used directly as the result.
:return: the Collator itself
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
backend=backend)
return self
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
def set_ignore(self, *field_names) -> "JittorDataLoader":
def set_ignore(self, *field_names) -> Collator:
"""
If some content is not wanted in the output, it can be set here; the specified fields will be ignored in the batch output.
Ex::
@@ -145,18 +145,17 @@ class JittorDataLoader:
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
def get_batch_indices(self) -> List[int]:
"""
Get the idx of the current data
Get the idx of the current batch
:return:
"""
return self.cur_batch_indices
def prepare_jittor_dataloader():
...
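The API change above, set_pad()/set_ignore() now returning the Collator rather than the dataloader, reads roughly like this at the call site; the toy DataSet and the import path are assumptions, not part of this diff:

# Assumed usage sketch for the changed return type.
from fastNLP import DataSet
from fastNLP.core.dataloaders import JittorDataLoader  # assumed import path

ds = DataSet({'words': [[1, 2], [3, 4, 5]], 'seq_len': [2, 3]})
dl = JittorDataLoader(ds, batch_size=2)
collator = dl.set_pad('words', pad_val=0, backend='jittor')  # now returns the Collator
collator.set_ignore('seq_len')  # further calls chain on the Collator itself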

View File

@@ -15,8 +15,9 @@ else:
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
from fastNLP.core.collators.collator import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet
from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler
class _PaddleDataset(Dataset):
@@ -54,6 +55,10 @@ class PaddleDataLoader(DataLoader):
if not isinstance(dataset, _PaddleDataset):
dataset = _PaddleDataset(dataset)
if batch_sampler is None:
batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle,
drop_last=drop_last)
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
return_list=return_list, batch_sampler=batch_sampler,
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
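With this hunk, PaddleDataLoader falls back to the new reproducible RandomBatchSampler whenever no batch_sampler is given. A hedged sketch of what that default produces on its own; the toy dataset is an assumption:

# Assumed behavior sketch of the new default batch_sampler.
from fastNLP.core.samplers import RandomBatchSampler

class ToyDataset:  # any container implementing __len__ will do
    def __len__(self):
        return 10

batch_sampler = RandomBatchSampler(ToyDataset(), batch_size=4, shuffle=True, drop_last=False)
for batch_indices in batch_sampler:
    print(batch_indices)  # lists of dataset indices, e.g. [3, 7, 0, 9]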
@@ -66,8 +71,6 @@ class PaddleDataLoader(DataLoader):
if isinstance(dataset.dataset, FDataSet):
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="paddle")
# if collate_fn is not None:
# self._collate_fn.add_collator(collate_fn)
else:
self._collate_fn = Collator(backend="paddle")
@@ -94,33 +97,33 @@ class PaddleDataLoader(DataLoader):
self.cur_batch_indices = indices
yield data
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> "PaddleDataLoader":
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
"""
Use this function if the content of a certain field needs special adjustment
Use this function if the content of a certain field needs special adjustment.
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ method returns a dict, the field's
key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. use ('a', 'b') for the
inner key of {'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' refer to the 0th and 1st elements of
the sequence. An error is raised if the field is not found in the data; if __getitem__ returns the content as a
whole, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field will not be padded. fastNLP only
pads fields that can be padded, so if the field itself is not in a paddable form, it is not necessary to set this
to None.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing output of type list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var respectively. When pad_val is None, this may only be None or 'numpy'.
:param pad_fn: a custom pad function for this field. If given, the pad_val, dtype and backend arguments are ignored.
pad_fn receives the batch form of this field: the Collator automatically unbatches the data and regroups each
field into its own batch, and that field batch is the input of pad_fn; its output is used directly as the result.
:return: the Collator itself
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ method returns a dict, the field's
key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. use ('a', 'b') for the
inner key of {'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' refer to the 0th and 1st elements of
the sequence. An error is raised if the field is not found in the data; if __getitem__ returns the content as a
whole, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field will not be padded. fastNLP only
pads fields that can be padded, so if the field itself is not in a paddable form, it is not necessary to set this
to None. When backend is None, this value is meaningless.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing output of type list,
numpy.ndarray, torch.Tensor, paddle.Tensor or jittor.Var respectively. When pad_val is None, this value is
meaningless.
:param pad_fn: a custom pad function for this field. If given, the pad_val, dtype and backend arguments are ignored.
pad_fn receives the batch form of this field: the Collator automatically unbatches the data and regroups each
field into its own batch, and that field batch is the input of pad_fn; its output is used directly as the result.
:return: the Collator itself
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
backend=backend)
return self
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
def set_ignore(self, *field_names) -> "PaddleDataLoader":
def set_ignore(self, *field_names) -> Collator:
"""
If some content is not wanted in the output, it can be set here; the specified fields will be ignored in the batch output.
Ex::
@@ -133,13 +136,13 @@ class PaddleDataLoader(DataLoader):
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
def get_batch_indices(self) -> List[int]:
"""
Get the idx of the current data
Get the idx of the current batch
:return:
"""
@@ -147,7 +150,8 @@
def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
return_list: bool = True, batch_sampler=None,
return_list: bool = True,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
train_batch_size: int = 1, shuffle: bool = False,
drop_last: bool = False, collate_fn: Union[Callable, str, None] = None,
num_workers: int = 0, use_buffer_reader: bool = True,

View File

@@ -3,14 +3,14 @@ __all__ = [
'prepare_torch_dataloader'
]
from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping
from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader, Sampler
@@ -76,6 +76,9 @@ class TorchDataLoader(DataLoader):
if not isinstance(dataset, _FDataSet):
dataset = _FDataSet(dataset)
if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle)
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -87,9 +90,6 @@ class TorchDataLoader(DataLoader):
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="torch")
# if collate_fn is not None and collate_fn is not default_collate:
# prevent the torch dataloader's default collate from being added back when ddp re-initializes
# self._collate_fn.add_collator(collate_fn)
else:
self._collate_fn = Collator(backend="torch")
else:
@@ -112,31 +112,32 @@ class TorchDataLoader(DataLoader):
yield data
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> "TorchDataLoader":
pad_fn:Callable=None) -> Collator:
"""
Use this function if the content of a certain field needs special adjustment
Use this function if the content of a certain field needs special adjustment.
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ method returns a dict, the field's
key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. use ('a', 'b') for the
inner key of {'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' refer to the 0th and 1st elements of
the sequence. An error is raised if the field is not found in the data; if __getitem__ returns the content as a
whole, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field will not be padded. fastNLP only
pads fields that can be padded, so if the field itself is not in a paddable form, it is not necessary to set this
to None.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing output of type list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var respectively. When pad_val is None, this may only be None or 'numpy'.
:param pad_fn: a custom pad function for this field. If given, the pad_val, dtype and backend arguments are ignored.
pad_fn receives the batch form of this field: the Collator automatically unbatches the data and regroups each
field into its own batch, and that field batch is the input of pad_fn; its output is used directly as the result.
:return: the Collator itself
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ method returns a dict, the field's
key can be used directly; for a nested dict, a tuple can express a multi-level key, e.g. use ('a', 'b') for the
inner key of {'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' refer to the 0th and 1st elements of
the sequence. An error is raised if the field is not found in the data; if __getitem__ returns the content as a
whole, use "_single".
:param pad_val: the default pad value for this field. If set to None, the field will not be padded. fastNLP only
pads fields that can be padded, so if the field itself is not in a paddable form, it is not necessary to set this
to None. When backend is None, this value is meaningless.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing output of type list,
numpy.ndarray, torch.Tensor, paddle.Tensor or jittor.Var respectively. When pad_val is None, this value is
meaningless.
:param pad_fn: a custom pad function for this field. If given, the pad_val, dtype and backend arguments are ignored.
pad_fn receives the batch form of this field: the Collator automatically unbatches the data and regroups each
field into its own batch, and that field batch is the input of pad_fn; its output is used directly as the result.
:return: the Collator itself
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
def set_ignore(self, *field_names) -> "TorchDataLoader":
def set_ignore(self, *field_names) -> Collator:
"""
If some content is not wanted in the output, it can be set here; the specified fields will be ignored in the batch output.
Ex::
@@ -149,24 +150,15 @@ class TorchDataLoader(DataLoader):
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
def get_batch_indices(self) -> List[int]:
"""
Get the idx of the current data
:return:
"""
return self.cur_batch_indices
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
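For context, a hedged sketch of driving the torch entry point after this change; the toy DataSet and the exact import path are assumptions:

# Assumed usage sketch: with neither sampler nor batch_sampler given,
# TorchDataLoader now defaults to fastNLP's reproducible RandomSampler.
from fastNLP import DataSet
from fastNLP.core.dataloaders import prepare_torch_dataloader  # assumed import path

ds = DataSet({'x': [[1, 2], [3, 4, 5]], 'y': [0, 1]})
dl = prepare_torch_dataloader(ds, batch_size=2, shuffle=True)
dl.set_pad('x', pad_val=0)  # returns the Collator
dl.set_ignore('y')          # returns the Collator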

View File

@@ -0,0 +1,16 @@
def indice_collate_wrapper(func):
"""
Wraps an extra layer around collate_fn: it separates the (idx, instance) tuples fetched from the dataset and packs the idx values into indices.
:param func: the function to wrap
:return:
"""
def wrapper(tuple_data):
indice, ins_list = [], []
for idx, ins in tuple_data:
indice.append(idx)
ins_list.append(ins)
return indice, func(ins_list)
return wrapper
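A small sketch of what the wrapper does; list_collate below is an assumed stand-in collate function:

# The dataset side yields (idx, instance) tuples; the wrapper splits them so
# the real collate_fn only ever sees the instances.
def list_collate(ins_list):
    return list(ins_list)  # assumed stand-in collate_fn

wrapped = indice_collate_wrapper(list_collate)
indices, batch = wrapped([(0, 'a'), (3, 'b'), (7, 'c')])
assert indices == [0, 3, 7]
assert batch == ['a', 'b', 'c']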

View File

@@ -780,7 +780,7 @@ class DataSet:
self.collator.set_ignore(*field_names)
@property
def collator(self):
def collator(self) -> Collator:
if self._collator is None:
self._collator = Collator()
return self._collator

View File

@@ -22,7 +22,7 @@ from fastNLP.core.utils import (
rank_zero_rm
)
from fastNLP.core.samplers import (
RandomBatchSampler,
ReproduceBatchSampler,
ReproducibleSampler,
ReproducibleBatchSampler,
RandomSampler,
@@ -485,7 +485,7 @@ class PaddleFleetDriver(PaddleDriver):
return self.model, model.forward
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]],
reproducible: bool = False):
r"""
Given the input dataloader, produce a dataloader that supports distributed training and is reproducible.

View File

@@ -22,7 +22,7 @@ from fastNLP.core.log import logger
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
ReproducibleSampler,
RandomBatchSampler,
ReproduceBatchSampler,
RandomSampler,
)
@@ -345,7 +345,7 @@ class PaddleDriver(Driver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
"`ReproducibleBatchSampler` or `ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
@@ -476,7 +476,7 @@ class PaddleDriver(Driver):
res.shuffle = True
else:
res.shuffle = False
# the RandomBatchSampler case
# the ReproduceBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler

View File

@@ -14,7 +14,7 @@ from fastNLP.core.utils import (
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
RandomBatchSampler,
ReproduceBatchSampler,
ReproducibleSampler,
RandomSampler,
re_instantiate_sampler,
@@ -177,7 +177,7 @@ class PaddleSingleDriver(PaddleDriver):
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
else:
batch_sampler = RandomBatchSampler(
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last

View File

@@ -15,7 +15,7 @@ from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler
from fastNLP.core.samplers import RandomSampler
from fastNLP.core.log import logger
@@ -113,7 +113,7 @@ class TorchSingleDriver(TorchDriver):
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
else:
batch_sampler = RandomBatchSampler(
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last

View File

@@ -31,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
class TorchDriver(Driver):
@@ -293,7 +293,7 @@ class TorchDriver(Driver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
"`ReproducibleBatchSampler` or `ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
@@ -407,7 +407,7 @@ class TorchDriver(Driver):
res.shuffle = True
else:
res.shuffle = False
# the RandomBatchSampler case
# the ReproduceBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler

View File

@@ -14,9 +14,10 @@ __all__ = [
"UnrepeatedSortedSampler",
"UnrepeatedSequentialSampler",
"RandomBatchSampler",
"ReproduceBatchSampler",
"BucketedBatchSampler",
"ReproducibleBatchSampler",
"RandomBatchSampler",
"re_instantiate_sampler"
]
@@ -26,5 +27,5 @@ from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, Polling
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
from .utils import re_instantiate_sampler
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
from .reproducible_batch_sampler import ReproduceBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler, RandomBatchSampler

View File

@@ -1,5 +1,6 @@
__all__ = [
'BucketedBatchSampler',
"ReproduceBatchSampler",
"RandomBatchSampler"
]
@@ -54,13 +55,13 @@ class ReproducibleBatchSampler:
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
class RandomBatchSampler(ReproducibleBatchSampler):
class ReproduceBatchSampler(ReproducibleBatchSampler):
# the values of these two arguments should be fetched via the driver's get_dataloader_args function;
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
"""
A wrapper that makes a batch_sampler object's state recoverable.
:param batch_sampler: an iterable yielding numbers or lists of numbers; RandomBatchSampler first iterates once over
:param batch_sampler: an iterable yielding numbers or lists of numbers; ReproduceBatchSampler first iterates once over
this object, caches the yielded indices, and then, when used, emits lists of indices in batches of batch_size.
:param batch_size: the size of each batch.
:param drop_last: whether to drop the remaining samples when the last batch cannot reach batch_size.
@@ -143,7 +144,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
self.need_reinitialize = False
def set_distributed(self, num_replicas, rank, pad=True):
raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.")
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")
def set_epoch(self, epoch):
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
@@ -158,6 +159,211 @@ class RandomBatchSampler(ReproducibleBatchSampler):
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
class RandomBatchSampler(ReproducibleBatchSampler):
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
drop_last: bool = False, seed: int = 0, **kwargs):
"""
随机分 batch batch_sampler
:param dataset: 实现了 __len__ 方法的数据容器
:param batch_size: 每个 batch 的大小
:param shuffle: 如果为 True将不进行 shuffle实际上数据会以从长到短的方式输出
:param drop_last: 如果最后一个 batch sample 数量无法凑齐 batch_size 这么多是否需要丢掉
:param seed: 设置的随机数种子
:param kwargs: fastNLP 保留使用
"""
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.seed = seed
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量
# parameters related to multi-card training
self.num_replicas = kwargs.get("num_replicas", 1)
self.rank = kwargs.get("rank", 0)
self.epoch = kwargs.get("epoch", -1)
self.pad = kwargs.get("pad", False) # 该参数在单卡上不具有任何意义;
# whether we are in the middle of an iteration; while True, calling set_distributed() and load_state_dict() is not allowed
self.during_iter = kwargs.get("during_iter", False)
# the variables below are internal, used for restoring state.
self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)
def set_distributed(self, num_replicas, rank, pad=True):
assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
"during an unfinished iteration."
assert num_replicas > 0 and isinstance(num_replicas, int)
assert isinstance(rank, int) and 0 <= rank < num_replicas
# note: when this function is called, all state should default to that of an epoch just about to start training;
self.num_replicas = num_replicas
self.rank = rank
self.pad = pad
return self
def __iter__(self):
if self.during_iter:  # if during_iter is True, the previous iteration has not finished; the only option is to force re-initialization
self.num_consumed_samples = 0
self.during_iter = True
indices = list(range(len(self.dataset)))
if self.shuffle:
if self.num_consumed_samples > 0:  # first rebuild the original ordering, then drop the already-consumed part
_batches = []
for _i in range(self.old_num_replicas):
_indices = indices[_i:len(indices):self.old_num_replicas]
__batches = self.batchify(_indices, self.old_batch_size, seed=self.seed + self.epoch)
_batches.append(__batches)
batches = list(chain(*[_ for _ in zip(*_batches)]))
indices = list(chain(*batches))
indices = indices[self.num_consumed_samples:]
# take out this rank's share
indices = indices[self.rank:len(indices):self.num_replicas]
batches = self.batchify(indices, self.batch_size, seed=self.seed + self.epoch)
batches = list(map(list, batches))
else:
indices = indices[self.num_consumed_samples:]
indices = indices[self.rank:len(indices):self.num_replicas]
_num_batches = len(indices) // self.batch_size
if _num_batches == 0:
batches = [indices]
else:
batches = list(map(list, np.array_split(indices[:_num_batches*self.batch_size], _num_batches)))
if len(indices)%self.batch_size!=0:
batches.append(indices[_num_batches*self.batch_size:])
need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
if len(batches) > 0:
if len(batches[-1])<self.batch_size:
batches[-1].append(batches[-1][0])  # this guarantees the length of this bucket is not broken.
else:
batches.append([batches[-1][0]])
elif self.pad is False and need_pad_num !=0 and need_pad_num>self.rank:
if len(batches):
batches[-1].pop(-1)
if len(batches[-1])==0:
batches.pop(-1)
assert sum(map(len, batches)) == self.num_left_samples
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
batches = batches[:-1]
for batch in batches:
self.num_consumed_samples += self.num_replicas * len(batch)
yield list(map(int, batch))
self.during_iter = False
self.num_consumed_samples = 0
self.old_batch_size = self.batch_size
self.old_num_replicas = self.num_replicas
if self.epoch < 0:  # guard against the user never updating the epoch, which would make every epoch identical
self.epoch -= 1
def batchify(self, indices, batch_size, seed):
"""
indices 分为 batches
:param sorted_indices: List[int]
:param batch_size: int
:param seed: int
:return: List[List[int]]
"""
# the actual bucket size
rng = np.random.default_rng(abs(seed))
rng.shuffle(indices)
num_samples = 0
batches = []
while num_samples<len(indices):
batches.append(indices[num_samples:num_samples+batch_size])
num_samples += batch_size
return batches
def set_epoch(self, epoch):
self.epoch = epoch
@property
def batch_idx_in_epoch(self):
if self.drop_last:
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
else:
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
(self.num_left_samples + self.batch_size - 1) // self.batch_size
@property
def total_size(self):
"""
The number of indices this sampler will eventually produce, including those for other ranks. Because of
replicas and pad, this value may be equal to, greater than, or less than len(dataset).
:return:
"""
return self.num_consumed_samples + self.num_replicas*self.num_left_samples
@property
def num_left_samples(self):
"""
Return how many samples are left unconsumed in the current iteration; this is the amount remaining on the current rank.
:return:
"""
num_consumed_samples = self.num_consumed_samples
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
def __len__(self)->int:
"""
Return how many more batches of data this sampler will yield.
:return:
"""
num_sampler_per_rank = self.total_size//self.num_replicas
num_batches = num_sampler_per_rank//self.batch_size if self.drop_last else \
(num_sampler_per_rank+self.batch_size-1)//self.batch_size
return num_batches
def state_dict(self) -> Dict:
if self.old_batch_size != self.batch_size:
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
" consumed. ")
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
'batch_size': self.batch_size,
'num_replicas': self.num_replicas}
return states
def load_state_dict(self, states: Dict):
# if self.during_iter is True, num_consumed_samples must be 0
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}, " \
f"we cannot use {self.__class__.__name__} to load it."
length = states['length']
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
"and current dataset."
self.seed = states['seed']
self.epoch = states['epoch']
self.num_consumed_samples = states['num_consumed_samples']
if self.num_consumed_samples>=length:  # if the last sample had already been reached when saving, simply reset to 0
self.num_consumed_samples = 0
if self.shuffle != states['shuffle']:
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
f"we use shuffle={states['shuffle']}")
self.shuffle = states["shuffle"]
self.old_batch_size = states['batch_size']
self.old_num_replicas = states['num_replicas']
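To make the resume semantics concrete, a hedged round-trip sketch using the class defined above; the toy dataset and the import path are assumptions:

# Consume two batches, snapshot the state, and let a fresh sampler continue
# with only the unseen indices (same seed and epoch give the same shuffle).
from fastNLP.core.samplers import RandomBatchSampler  # assumed import path

class ToyDataset:  # stand-in container; only __len__ is required
    def __len__(self):
        return 20

sampler = RandomBatchSampler(ToyDataset(), batch_size=4, shuffle=True, seed=1)
sampler.set_epoch(0)
it = iter(sampler)
seen = set(next(it)) | set(next(it))  # 8 samples consumed

state = sampler.state_dict()  # records seed, epoch, num_consumed_samples=8, ...

resumed = RandomBatchSampler(ToyDataset(), batch_size=4, shuffle=True, seed=1)
resumed.load_state_dict(state)
resumed.set_epoch(0)
rest = {i for batch in resumed for i in batch}
assert rest.isdisjoint(seen) and len(seen | rest) == 20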
class BucketedBatchSampler(ReproducibleBatchSampler):
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):

View File

@@ -54,13 +54,12 @@ class RandomSampler(ReproducibleSampler):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
"""
:param dataset: a data container implementing __len__
:param shuffle: whether to shuffle the order on each iteration
:param seed: the random seed
:param kwargs: not meant for users; used internally by fastNLP
"""
super(RandomSampler, self).__init__()
self.dataset = dataset
self.shuffle = shuffle
self.seed = seed

View File

@@ -21,7 +21,6 @@ __all__ = [
'nullcontext',
'pretty_table_printer',
'Option',
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'rank_zero_rm',
@@ -37,6 +36,7 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device
from .torch_utils import torch_move_data_to_device
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
from ..dataloaders.utils import indice_collate_wrapper

View File

@@ -6,7 +6,7 @@ import warnings
from dataclasses import is_dataclass
from copy import deepcopy
from collections import defaultdict, OrderedDict
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence
from typing import Tuple, Optional
from time import sleep
@@ -35,7 +35,6 @@ __all__ = [
'nullcontext',
'pretty_table_printer',
'Option',
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'rank_zero_rm',
@@ -513,24 +512,6 @@ class Option(dict):
self.update(state)
def indice_collate_wrapper(func):
"""
Wraps an extra layer around collate_fn: it separates the (idx, instance) tuples fetched from the dataset and packs the idx values into indices.
:param func: the function to wrap
:return:
"""
def wrapper(tuple_data):
indice, ins_list = [], []
for idx, ins in tuple_data:
indice.append(idx)
ins_list.append(ins)
return indice, func(ins_list)
return wrapper
_emitted_deprecation_warnings = set()

View File

@@ -2,7 +2,7 @@ import pytest
from pathlib import Path
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -278,7 +278,7 @@ class TestPaddleDriverFunctions:
dataset = PaddleNormalDataset()
dataloader = DataLoader(
dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
batch_size,
drop_last,
@@ -287,7 +287,7 @@ class TestPaddleDriverFunctions:
res = PaddleSingleDriver.get_dataloader_args(dataloader)
assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
assert isinstance(res.batch_sampler, ReproduceBatchSampler)
if shuffle:
assert isinstance(res.sampler, paddle.io.RandomSampler)
else:
@@ -387,7 +387,7 @@ class TestSetDistReproDataloader:
"""
Test the behavior of set_dist_repro_dataloader when the `reproducible` parameter is True.
When dist is a string, a new dataloader should be returned; if the original sampler is paddle.io.RandomSampler (shuffle=True),
only the Sampler is replaced, with RandomSampler; otherwise the batch_sampler is replaced, with RandomBatchSampler
only the Sampler is replaced, with RandomSampler; otherwise the batch_sampler is replaced, with ReproduceBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
@@ -400,7 +400,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# in this case the batch_sampler is replaced
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@@ -414,11 +414,11 @@ class TestSetDistReproDataloader:
A new dataloader should be returned, with its batch_sampler replaced by the Sampler corresponding to dist.
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler is dist
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
@@ -450,7 +450,7 @@ class TestSetDistReproDataloader:
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
batch_size=4,
drop_last=False,
@@ -459,7 +459,7 @@ class TestSetDistReproDataloader:
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@@ -500,20 +500,20 @@ class TestSetDistReproDataloader:
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
# reload; the remaining content should be produced, and for PaddleNormalDataset the sorted result should be a range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# rebuild the dataloader
new_loader = DataLoader(
dataset=replaced_loader.dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
batch_size=batch_size,
drop_last=False,
@@ -603,7 +603,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
dataset = PaddleRandomMaxDataset(40, 10)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
)
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")
@@ -627,7 +627,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# change the batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
@@ -637,7 +637,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# 2. check that the batch_sampler is correctly loaded and replaced
assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4

View File

@@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.utils import (
replace_batch_sampler,
replace_sampler,
)
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE:
import paddle
@@ -36,12 +36,12 @@ def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices,
def test_replace_batch_sampler():
dataset = PaddleNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32)
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
replaced_loader = replace_batch_sampler(dataloader, batch_sampler)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.dataset, PaddleNormalDataset)
assert len(replaced_loader.dataset) == len(dataset)
assert replaced_loader.batch_sampler.batch_size == 16

View File

@@ -2,7 +2,7 @@ import pytest
from pathlib import Path
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
@@ -17,7 +17,7 @@ if _NEED_IMPORT_PADDLE:
def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
"""
Build a dataloader whose batch_sampler is a RandomBatchSampler
Build a dataloader whose batch_sampler is a ReproduceBatchSampler
"""
if shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
@@ -25,7 +25,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
sampler = torch.utils.data.SequentialSampler(dataset)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(
sampler, batch_size=batch_size, drop_last=drop_last
),
@@ -306,7 +306,7 @@ class TestTorchDriverFunctions:
res = TorchSingleDriver.get_dataloader_args(dataloader)
assert isinstance(res.dataset, TorchNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
assert isinstance(res.batch_sampler, ReproduceBatchSampler)
if shuffle:
assert isinstance(res.sampler, torch.utils.data.RandomSampler)
else:
@@ -401,7 +401,7 @@ class TestSetDistReproDataloader:
"""
Test the behavior of set_dist_repro_dataloader when the `reproducible` parameter is True.
When dist is a string, a new dataloader should be returned; if the original sampler is torch.utils.data.RandomSampler (shuffle=True),
only the Sampler is replaced, with RandomSampler; otherwise the batch_sampler is replaced, with RandomBatchSampler
only the Sampler is replaced, with RandomSampler; otherwise the batch_sampler is replaced, with ReproduceBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
@@ -414,7 +414,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# in this case the batch_sampler is replaced
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@@ -428,11 +428,11 @@ class TestSetDistReproDataloader:
A new dataloader should be returned, with its batch_sampler replaced by the Sampler corresponding to dist.
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler is dist
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
@@ -466,7 +466,7 @@ class TestSetDistReproDataloader:
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@@ -502,14 +502,14 @@ class TestSetDistReproDataloader:
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
# reload; the remaining content should be produced, and for TorchNormalDataset the sorted result should be a range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# rebuild the dataloader
@@ -613,7 +613,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# 2. check that the batch_sampler is correctly loaded and replaced
assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4

View File

@@ -30,7 +30,7 @@ class SequenceDataSet:
def check_replace_sampler(driver):
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / RandomBatchSampler
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / ReproduceBatchSampler
# reproducible is True or False
# need to check that both the returned sampler and dataloader have changed

View File

@@ -4,7 +4,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
replace_batch_sampler,
replace_sampler,
)
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from torch.utils.data import DataLoader, BatchSampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -14,12 +14,12 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
def test_replace_batch_sampler():
dataset = TorchNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32)
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
replaced_loader = replace_batch_sampler(dataloader, batch_sampler)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.dataset, TorchNormalDataset)
assert len(replaced_loader.dataset) == len(dataset)
assert replaced_loader.batch_sampler.batch_size == 16

View File

@@ -5,7 +5,7 @@ import pytest
from itertools import chain
from copy import deepcopy
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -19,7 +19,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# forward_steps = 3
@@ -29,15 +29,15 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
#
# # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
# state = _get_re_batchsampler.state_dict()
# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
# "sampler_type": "RandomBatchSampler"}
# "sampler_type": "ReproduceBatchSampler"}
#
# # 2. resume from checkpoint: rebuild a dataloader
# # without changing batch_size
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
@@ -54,7 +54,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
# # change batch_size
# after_batch_size = 3
# dataloader = DataLoader(dataset, batch_size=after_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
@@ -100,7 +100,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
# dataset = TorchNormalDataset(num_of_data=100)
# enable shuffle, to check that the second epoch's index list is regenerated after resuming;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# record all data of one epoch, to check whether the restored data is correct;
@@ -112,13 +112,13 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
#
# # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
# state = _get_re_batchsampler.state_dict()
#
# # 2. resume from checkpoint: rebuild a dataloader
# # without changing batch_size
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
@@ -511,3 +511,313 @@ class TestBucketedBatchSampler:
already_seen_set.update(batch)
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)
class TestRandomBatchSampler:
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71])
def test_single_num_batch(self, shuffle, drop_last, num):
# insufficient data should not raise an error
for num in [2, 7, 14, 15, 70, 71]:
dataset = DatasetWithVaryLength(num_of_data=num)
before_batch_size = 7
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
count = len(list(iter(re_batchsampler)))
if drop_last:
assert count==num//before_batch_size, num
else:
assert count==(num+before_batch_size-1)//before_batch_size, num
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
def test_single(self, shuffle, drop_last):
before_batch_size = 7
num_batch_per_bucket = 4  # then the length difference within any batch should not exceed 4
dataset = DatasetWithVaryLength(num_of_data=1000)
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler.set_epoch(0)
forward_steps = 10
iterator = iter(re_batchsampler)
already_generate_indices = set()
for _ in range(forward_steps):
batch = next(iterator)
already_generate_indices.update(batch)
# 1. save the state
state = re_batchsampler.state_dict()
# 2. resume from checkpoint and continue training
re_batchsampler2 = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler2.load_state_dict(state)
re_batchsampler2.set_epoch(0)
new_already_generate_indices = set()
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
indices = np.arange(len(dataset))[mask]
max_diff = -1
for i in range(len(indices)-before_batch_size * num_batch_per_bucket):
max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i])
for batch in re_batchsampler2:
for b in batch:
assert b not in already_generate_indices
new_already_generate_indices.update(batch)
if drop_last is False:
assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset)
# change batch_size
after_batch_size = 3
re_batchsampler3 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler3.load_state_dict(state)
re_batchsampler3.set_epoch(0)
count = 0
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
indices = np.arange(len(dataset))[mask]
for batch in re_batchsampler3:
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
count += 1
if count > 5:
break
# saving again is not allowed before the previous epoch has finished sampling
after_batch_size = 5
with pytest.raises(RuntimeError):
state = re_batchsampler3.state_dict()
for batch in re_batchsampler3:  # consume all; only then can we save
pass
already_generate_indices = set()
count = 0
for batch in re_batchsampler3:  # start over
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
count += 1
if count > 5:
break
state = re_batchsampler3.state_dict()
# drop_last is False here, so in the end all samples must be covered
re_batchsampler4 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
drop_last=False,
shuffle=shuffle)
re_batchsampler4.load_state_dict(state)
re_batchsampler4.set_epoch(0)
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
for batch in re_batchsampler4:
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
assert len(already_generate_indices) == len(dataset)
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
def test_multi(self, shuffle, drop_last, pad):
# def test_multi(self, shuffle=True, drop_last=False, pad=False):
# no shuffle
num_replica = 2
dataset = DatasetWithVaryLength(num_of_data=1000)
batch_size = 5
num_batch_per_bucket = 10
lengths = []
rank0_already_seen_indexes = None
max_diff = num_batch_per_bucket * batch_size * num_replica
for rank in range(num_replica):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
shuffle = shuffle, drop_last=drop_last)
sampler.set_epoch(0)
sampler.set_distributed(num_replica, rank=rank, pad=pad)
lengths.append(len(sampler))
already_seen_indexes = set()
repeat_count = 0
for batch in sampler:
for b in batch:
repeat_count += int(b in already_seen_indexes)
if rank0_already_seen_indexes:  # indices must not appear on both ranks
assert b not in rank0_already_seen_indexes
already_seen_indexes.update(batch)
if rank0_already_seen_indexes is None:
rank0_already_seen_indexes = already_seen_indexes
if pad:  # one repeat should be allowed
assert repeat_count<=1
else:
assert repeat_count==0
assert len(set(lengths))==1, lengths  # each process has the same number of batches
# saving in the multi-process case
already_seen_indexes = set()
for rank in range(num_replica):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
shuffle = shuffle, drop_last=drop_last)
sampler.set_epoch(0)
sampler.set_distributed(num_replica, rank=rank, pad=pad)
lengths.append(len(sampler))
count = 0
for batch in sampler:
already_seen_indexes.update(batch)
if count>5:
break
count += 1
state = sampler.state_dict()
# switch to a single machine
new_batch_size = 6
num_batch_per_bucket = 3
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
shuffle=shuffle, drop_last=drop_last)
new_sampler.load_state_dict(state)
repeat_count = 0
new_already_seen_indexes = set(list(already_seen_indexes))
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_seen_indexes)] = 0
indices = np.arange(len(dataset))[mask]
for batch in new_sampler:
for b in batch:
repeat_count += int(b in new_already_seen_indexes)
new_already_seen_indexes.update(batch)
if pad:  # one repeat should be allowed
assert repeat_count <= 1
else:
assert repeat_count == 0
if drop_last is False:  # if nothing is dropped, they should be equal
assert len(new_already_seen_indexes)==len(dataset)
# test changing the number of cards.
num_replica = 3
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
shuffle=shuffle, drop_last=drop_last)
new_sampler.set_epoch(0)
new_sampler.load_state_dict(state)
new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad)
repeat_count = 0
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_seen_indexes)] = 0
indices = np.arange(len(dataset))[mask]
for batch in new_sampler:
for b in batch:
repeat_count += int(b in already_seen_indexes)
if pad:  # one repeat should be allowed
assert repeat_count <= 1
else:
assert repeat_count == 0
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2):
dataset = DatasetWithVaryLength(num_of_data=num_samples)
batch_size = 6
if num_replicas*batch_size > num_samples:
return
num_batch_per_bucket = 10
samplers = []
lengths = []
for i in range(num_replicas):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
shuffle=shuffle, drop_last=drop_last)
sampler.set_distributed(num_replicas, rank=i, pad=pad)
sampler.set_epoch(0)
samplers.append(sampler)
lengths.append(len(list(iter(sampler))))
assert len(set(lengths))==1
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [1, 2, 3])
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
"""
Test whether the already-consumed forward data can be correctly restored
:return:
"""
batch_size = 6
dataset = DatasetWithVaryLength(num_of_data=num_samples)
samplers = []
num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas))
for i in range(num_replicas):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
shuffle=shuffle, drop_last=drop_last)
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
samplers.append(sampler)
count = 0
already_seen_sets = [set()]
already_seen_set = set()
for batchs in zip(*samplers):
batch = chain(*batchs)
already_seen_set.update(batch)
already_seen_sets.append(deepcopy(already_seen_set))
count += 1
if count > 3:
break
states = samplers[0].state_dict()
for i in range(len(already_seen_sets)):
states['num_consumed_samples'] = num_consumed_samples_array[i]
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
shuffle=shuffle, drop_last=drop_last)
sampler.set_epoch(0)
already_seen_set = deepcopy(already_seen_sets[i])
for batch in sampler:
already_seen_set.update(batch)
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(
dataset)
# test saving again after a save
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
shuffle=shuffle,
drop_last=drop_last)
sampler.set_epoch(0)
states['num_consumed_samples'] = num_consumed_samples_array[2]
if len(already_seen_sets)<3:
return
already_seen_set = already_seen_sets[2]
count = 0
for batch in sampler:
already_seen_set.update(batch)
count += 1
if count > 6:
break
states = sampler.state_dict()
num_consumed_samples_array = list(range(len(dataset)))
states['num_consumed_samples'] = num_consumed_samples_array[count]
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
shuffle=shuffle,
drop_last=drop_last)
sampler.load_state_dict(states)
sampler.set_epoch(0)
for batch in sampler:
already_seen_set.update(batch)
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)