mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
增加RandomBatchSampler
This commit is contained in:
parent
4e1b74c4cb
commit
7d5ce620f4
@ -65,12 +65,16 @@ def _get_backend() -> str:
|
||||
return catch_backend[0]
|
||||
|
||||
# 方式 (2)
|
||||
for backend in CHECK_BACKEND:
|
||||
if backend in sys.modules:
|
||||
logger.debug(f"sys.modules contains backend:{catch_backend[0]}.")
|
||||
return backend
|
||||
for key, module in sys.modules.items():
|
||||
catch_backend = _check_module(module)
|
||||
if catch_backend:
|
||||
break
|
||||
if len(catch_backend):
|
||||
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
|
||||
logger.debug(f"Find a module file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
|
||||
return catch_backend[0]
|
||||
|
||||
return 'numpy'
|
||||
@ -227,7 +231,7 @@ class Collator:
|
||||
设置可以 pad 的 field 默认 pad 为什么类型的 tensor
|
||||
|
||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None],
|
||||
若为 auto ,则在进行 pad 的时候会根据调用的环境决定其 backend 。
|
||||
若为 auto ,则在进行 pad 的时候会自动根据调用的环境决定其 backend 。
|
||||
:return:
|
||||
"""
|
||||
assert backend in SUPPORTED_BACKENDS
|
||||
|
@ -74,7 +74,7 @@ def _get_dtype(ele_dtype, dtype, class_name):
|
||||
elif is_numpy_generic_class(ele_dtype):
|
||||
dtype = numpy_to_paddle_dtype_dict.get(ele_dtype)
|
||||
else:
|
||||
dtype == ele_dtype
|
||||
dtype = ele_dtype
|
||||
|
||||
return dtype
|
||||
|
||||
@ -174,6 +174,5 @@ def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0):
|
||||
"""
|
||||
shapes = get_shape(batch_field)
|
||||
tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype)
|
||||
# tensor = paddle.full(shape=shapes, dtype=dtype, fill_value=pad_val)
|
||||
tensor = fill_tensor(batch_field, tensor, dtype=dtype)
|
||||
return tensor
|
||||
|
@ -363,7 +363,6 @@ class Trainer(TrainerEventTrigger):
|
||||
raise e
|
||||
finally:
|
||||
self.on_train_end()
|
||||
self.driver.barrier()
|
||||
|
||||
def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
|
||||
def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None:
|
||||
@ -441,6 +440,7 @@ class Trainer(TrainerEventTrigger):
|
||||
"""
|
||||
_own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"])
|
||||
_own_callbacks.extend(self._custom_callbacks[None])
|
||||
logger.debug(f"Get {len(_own_callbacks)} callback fns through Trainer.on().")
|
||||
self._custom_callbacks[None] = []
|
||||
if self.marker is not None:
|
||||
if len(self._custom_callbacks[self.marker]) == 0:
|
||||
|
@ -14,7 +14,7 @@ else:
|
||||
from fastNLP.core.dataset import DataSet as Dataset
|
||||
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
|
||||
from fastNLP.core.collators import Collator
|
||||
from fastNLP.core.utils.utils import indice_collate_wrapper
|
||||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
|
||||
from fastNLP.core.dataset import DataSet as FDataSet
|
||||
|
||||
|
||||
@ -106,33 +106,33 @@ class JittorDataLoader:
|
||||
return len(self.dataset) // self.dataset.batch_size
|
||||
return (len(self.dataset) - 1) // self.dataset.batch_size + 1
|
||||
|
||||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
|
||||
pad_fn: Callable = None) -> "JittorDataLoader":
|
||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
|
||||
pad_fn:Callable=None) -> Collator:
|
||||
"""
|
||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
|
||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
|
||||
|
||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
|
||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
|
||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
|
||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
|
||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。
|
||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
|
||||
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor,
|
||||
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。
|
||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
|
||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
|
||||
形式,输出将被直接作为结果输出。
|
||||
:return: 返回 Collator 自身
|
||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
|
||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
|
||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
|
||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
|
||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
|
||||
无意义。
|
||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
|
||||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
|
||||
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。
|
||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
|
||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
|
||||
形式,输出将被直接作为结果输出。
|
||||
:return: 返回 Collator 自身
|
||||
"""
|
||||
if isinstance(self._collate_fn, Collator):
|
||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
|
||||
backend=backend)
|
||||
return self
|
||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
|
||||
return self._collate_fn
|
||||
else:
|
||||
raise ValueError(f"collate_fn is not fastnlp collator")
|
||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
|
||||
|
||||
def set_ignore(self, *field_names) -> "JittorDataLoader":
|
||||
def set_ignore(self, *field_names) -> Collator:
|
||||
"""
|
||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||
Ex::
|
||||
@ -145,18 +145,17 @@ class JittorDataLoader:
|
||||
"""
|
||||
if isinstance(self._collate_fn, Collator):
|
||||
self._collate_fn.set_ignore(*field_names)
|
||||
return self
|
||||
return self._collate_fn
|
||||
else:
|
||||
raise ValueError(f"collate_fn is not fastnlp collator")
|
||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
||||
|
||||
def get_batch_indices(self) -> List[int]:
|
||||
"""
|
||||
获取当前数据的idx
|
||||
获取当前 batch 的 idx
|
||||
|
||||
:return:
|
||||
"""
|
||||
return self.cur_batch_indices
|
||||
|
||||
|
||||
def prepare_jittor_dataloader():
|
||||
...
|
||||
|
@ -15,8 +15,9 @@ else:
|
||||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
|
||||
|
||||
from fastNLP.core.collators.collator import Collator
|
||||
from fastNLP.core.utils.utils import indice_collate_wrapper
|
||||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
|
||||
from fastNLP.core.dataset import DataSet as FDataSet
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler
|
||||
|
||||
|
||||
class _PaddleDataset(Dataset):
|
||||
@ -54,6 +55,10 @@ class PaddleDataLoader(DataLoader):
|
||||
if not isinstance(dataset, _PaddleDataset):
|
||||
dataset = _PaddleDataset(dataset)
|
||||
|
||||
if batch_sampler is None:
|
||||
batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle,
|
||||
drop_last=drop_last)
|
||||
|
||||
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
|
||||
return_list=return_list, batch_sampler=batch_sampler,
|
||||
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
|
||||
@ -66,8 +71,6 @@ class PaddleDataLoader(DataLoader):
|
||||
if isinstance(dataset.dataset, FDataSet):
|
||||
self._collate_fn = dataset.dataset.collator
|
||||
self._collate_fn.set_backend(backend="paddle")
|
||||
# if collate_fn is not None:
|
||||
# self._collate_fn.add_collator(collate_fn)
|
||||
else:
|
||||
self._collate_fn = Collator(backend="paddle")
|
||||
|
||||
@ -94,33 +97,33 @@ class PaddleDataLoader(DataLoader):
|
||||
self.cur_batch_indices = indices
|
||||
yield data
|
||||
|
||||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
|
||||
pad_fn: Callable = None) -> "PaddleDataLoader":
|
||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
|
||||
pad_fn:Callable=None) -> Collator:
|
||||
"""
|
||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
|
||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
|
||||
|
||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
|
||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
|
||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
|
||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
|
||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。
|
||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
|
||||
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor,
|
||||
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。
|
||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
|
||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
|
||||
形式,输出将被直接作为结果输出。
|
||||
:return: 返回 Collator 自身
|
||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
|
||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
|
||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
|
||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
|
||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
|
||||
无意义。
|
||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
|
||||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
|
||||
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。
|
||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
|
||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
|
||||
形式,输出将被直接作为结果输出。
|
||||
:return: 返回 Collator 自身
|
||||
"""
|
||||
if isinstance(self._collate_fn, Collator):
|
||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
|
||||
backend=backend)
|
||||
return self
|
||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
|
||||
return self._collate_fn
|
||||
else:
|
||||
raise ValueError(f"collate_fn is not fastnlp collator")
|
||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
|
||||
|
||||
def set_ignore(self, *field_names) -> "PaddleDataLoader":
|
||||
def set_ignore(self, *field_names) -> Collator:
|
||||
"""
|
||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||
Ex::
|
||||
@ -133,13 +136,13 @@ class PaddleDataLoader(DataLoader):
|
||||
"""
|
||||
if isinstance(self._collate_fn, Collator):
|
||||
self._collate_fn.set_ignore(*field_names)
|
||||
return self
|
||||
return self._collate_fn
|
||||
else:
|
||||
raise ValueError(f"collate_fn is not fastnlp collator")
|
||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
||||
|
||||
def get_batch_indices(self) -> List[int]:
|
||||
"""
|
||||
获取当前数据的idx
|
||||
获取当前 batch 的 idx
|
||||
|
||||
:return:
|
||||
"""
|
||||
@ -147,7 +150,8 @@ class PaddleDataLoader(DataLoader):
|
||||
|
||||
|
||||
def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
|
||||
return_list: bool = True, batch_sampler=None,
|
||||
return_list: bool = True,
|
||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
|
||||
train_batch_size: int = 1, shuffle: bool = False,
|
||||
drop_last: bool = False, collate_fn: Union[Callable, str, None] = None,
|
||||
num_workers: int = 0, use_buffer_reader: bool = True,
|
||||
|
@ -3,14 +3,14 @@ __all__ = [
|
||||
'prepare_torch_dataloader'
|
||||
]
|
||||
|
||||
from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping
|
||||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.collators import Collator
|
||||
from fastNLP.core.utils.utils import indice_collate_wrapper
|
||||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
|
||||
from fastNLP.io.data_bundle import DataBundle
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
from torch.utils.data import DataLoader, Sampler
|
||||
@ -76,6 +76,9 @@ class TorchDataLoader(DataLoader):
|
||||
if not isinstance(dataset, _FDataSet):
|
||||
dataset = _FDataSet(dataset)
|
||||
|
||||
if sampler is None and batch_sampler is None:
|
||||
sampler = RandomSampler(dataset, shuffle=shuffle)
|
||||
|
||||
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
|
||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
|
||||
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
|
||||
@ -87,9 +90,6 @@ class TorchDataLoader(DataLoader):
|
||||
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset
|
||||
self._collate_fn = dataset.dataset.collator
|
||||
self._collate_fn.set_backend(backend="torch")
|
||||
# if collate_fn is not None and collate_fn is not default_collate:
|
||||
# # 防止ddp重新初始化时候将torch dataloader的默认collate加进来
|
||||
# self._collate_fn.add_collator(collate_fn)
|
||||
else:
|
||||
self._collate_fn = Collator(backend="torch")
|
||||
else:
|
||||
@ -112,31 +112,32 @@ class TorchDataLoader(DataLoader):
|
||||
yield data
|
||||
|
||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
|
||||
pad_fn:Callable=None) -> "TorchDataLoader":
|
||||
pad_fn:Callable=None) -> Collator:
|
||||
"""
|
||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
|
||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
|
||||
|
||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
|
||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
|
||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
|
||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
|
||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。
|
||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
|
||||
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor,
|
||||
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。
|
||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
|
||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
|
||||
形式,输出将被直接作为结果输出。
|
||||
:return: 返回 Collator 自身
|
||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
|
||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
|
||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
|
||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
|
||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
|
||||
无意义。
|
||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
|
||||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
|
||||
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。
|
||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
|
||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
|
||||
形式,输出将被直接作为结果输出。
|
||||
:return: 返回 Collator 自身
|
||||
"""
|
||||
if isinstance(self._collate_fn, Collator):
|
||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
|
||||
return self
|
||||
return self._collate_fn
|
||||
else:
|
||||
raise ValueError(f"collate_fn is not fastnlp collator")
|
||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
|
||||
|
||||
def set_ignore(self, *field_names) -> "TorchDataLoader":
|
||||
def set_ignore(self, *field_names) -> Collator:
|
||||
"""
|
||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||
Ex::
|
||||
@ -149,24 +150,15 @@ class TorchDataLoader(DataLoader):
|
||||
"""
|
||||
if isinstance(self._collate_fn, Collator):
|
||||
self._collate_fn.set_ignore(*field_names)
|
||||
return self
|
||||
return self._collate_fn
|
||||
else:
|
||||
raise ValueError(f"collate_fn is not fastnlp collator")
|
||||
|
||||
def get_batch_indices(self) -> List[int]:
|
||||
"""
|
||||
获取当前数据的idx
|
||||
|
||||
:return:
|
||||
"""
|
||||
return self.cur_batch_indices
|
||||
|
||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
||||
|
||||
|
||||
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
|
||||
batch_size: int = 1,
|
||||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
|
||||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
|
||||
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
|
||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
|
||||
num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
|
||||
pin_memory: bool = False, drop_last: bool = False,
|
||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
|
||||
|
16
fastNLP/core/dataloaders/utils.py
Normal file
16
fastNLP/core/dataloaders/utils.py
Normal file
@ -0,0 +1,16 @@
|
||||
def indice_collate_wrapper(func):
|
||||
"""
|
||||
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。
|
||||
|
||||
:param func: 需要修饰的函数
|
||||
:return:
|
||||
"""
|
||||
|
||||
def wrapper(tuple_data):
|
||||
indice, ins_list = [], []
|
||||
for idx, ins in tuple_data:
|
||||
indice.append(idx)
|
||||
ins_list.append(ins)
|
||||
return indice, func(ins_list)
|
||||
|
||||
return wrapper
|
@ -780,7 +780,7 @@ class DataSet:
|
||||
self.collator.set_ignore(*field_names)
|
||||
|
||||
@property
|
||||
def collator(self):
|
||||
def collator(self) -> Collator:
|
||||
if self._collator is None:
|
||||
self._collator = Collator()
|
||||
return self._collator
|
||||
|
@ -22,7 +22,7 @@ from fastNLP.core.utils import (
|
||||
rank_zero_rm
|
||||
)
|
||||
from fastNLP.core.samplers import (
|
||||
RandomBatchSampler,
|
||||
ReproduceBatchSampler,
|
||||
ReproducibleSampler,
|
||||
ReproducibleBatchSampler,
|
||||
RandomSampler,
|
||||
@ -485,7 +485,7 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
|
||||
return self.model, model.forward
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]],
|
||||
reproducible: bool = False):
|
||||
r"""
|
||||
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。
|
||||
|
@ -22,7 +22,7 @@ from fastNLP.core.log import logger
|
||||
from fastNLP.core.samplers import (
|
||||
ReproducibleBatchSampler,
|
||||
ReproducibleSampler,
|
||||
RandomBatchSampler,
|
||||
ReproduceBatchSampler,
|
||||
RandomSampler,
|
||||
)
|
||||
|
||||
@ -345,7 +345,7 @@ class PaddleDriver(Driver):
|
||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
|
||||
"`ReproducibleSampler`.")
|
||||
else:
|
||||
sampler = RandomBatchSampler(
|
||||
sampler = ReproduceBatchSampler(
|
||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
|
||||
batch_size=dataloader_args.batch_size,
|
||||
drop_last=dataloader_args.drop_last
|
||||
@ -476,7 +476,7 @@ class PaddleDriver(Driver):
|
||||
res.shuffle = True
|
||||
else:
|
||||
res.shuffle = False
|
||||
# RandomBatchSampler 的情况
|
||||
# ReproduceBatchSampler 的情况
|
||||
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
|
||||
batch_sampler = dataloader.batch_sampler.batch_sampler
|
||||
res.sampler = batch_sampler.sampler
|
||||
|
@ -14,7 +14,7 @@ from fastNLP.core.utils import (
|
||||
from fastNLP.core.utils.utils import _get_fun_msg
|
||||
from fastNLP.core.samplers import (
|
||||
ReproducibleBatchSampler,
|
||||
RandomBatchSampler,
|
||||
ReproduceBatchSampler,
|
||||
ReproducibleSampler,
|
||||
RandomSampler,
|
||||
re_instantiate_sampler,
|
||||
@ -177,7 +177,7 @@ class PaddleSingleDriver(PaddleDriver):
|
||||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
|
||||
return replace_sampler(dataloader, sampler)
|
||||
else:
|
||||
batch_sampler = RandomBatchSampler(
|
||||
batch_sampler = ReproduceBatchSampler(
|
||||
batch_sampler=args.batch_sampler,
|
||||
batch_size=args.batch_size,
|
||||
drop_last=args.drop_last
|
||||
|
@ -15,7 +15,7 @@ from .torch_driver import TorchDriver
|
||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
|
||||
from fastNLP.core.utils import auto_param_call
|
||||
from fastNLP.core.utils.utils import _get_fun_msg
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler
|
||||
from fastNLP.core.samplers import RandomSampler
|
||||
from fastNLP.core.log import logger
|
||||
|
||||
@ -113,7 +113,7 @@ class TorchSingleDriver(TorchDriver):
|
||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
|
||||
return replace_sampler(dataloader, sampler)
|
||||
else:
|
||||
batch_sampler = RandomBatchSampler(
|
||||
batch_sampler = ReproduceBatchSampler(
|
||||
batch_sampler=args.batch_sampler,
|
||||
batch_size=args.batch_size,
|
||||
drop_last=args.drop_last
|
||||
|
@ -31,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
|
||||
from fastNLP.envs import rank_zero_call
|
||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
|
||||
|
||||
|
||||
class TorchDriver(Driver):
|
||||
@ -293,7 +293,7 @@ class TorchDriver(Driver):
|
||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
|
||||
"`ReproducibleSampler`.")
|
||||
else:
|
||||
sampler = RandomBatchSampler(
|
||||
sampler = ReproduceBatchSampler(
|
||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
|
||||
batch_size=dataloader_args.batch_size,
|
||||
drop_last=dataloader_args.drop_last
|
||||
@ -407,7 +407,7 @@ class TorchDriver(Driver):
|
||||
res.shuffle = True
|
||||
else:
|
||||
res.shuffle = False
|
||||
# RandomBatchSampler 的情况
|
||||
# ReproduceBatchSampler 的情况
|
||||
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
|
||||
batch_sampler = dataloader.batch_sampler.batch_sampler
|
||||
res.sampler = batch_sampler.sampler
|
||||
|
@ -14,9 +14,10 @@ __all__ = [
|
||||
"UnrepeatedSortedSampler",
|
||||
"UnrepeatedSequentialSampler",
|
||||
|
||||
"RandomBatchSampler",
|
||||
"ReproduceBatchSampler",
|
||||
"BucketedBatchSampler",
|
||||
"ReproducibleBatchSampler",
|
||||
"RandomBatchSampler",
|
||||
|
||||
"re_instantiate_sampler"
|
||||
]
|
||||
@ -26,5 +27,5 @@ from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, Polling
|
||||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
|
||||
from .utils import re_instantiate_sampler
|
||||
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler
|
||||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
|
||||
from .reproducible_batch_sampler import ReproduceBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler, RandomBatchSampler
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
__all__ = [
|
||||
'BucketedBatchSampler',
|
||||
"ReproduceBatchSampler",
|
||||
"RandomBatchSampler"
|
||||
]
|
||||
|
||||
@ -54,13 +55,13 @@ class ReproducibleBatchSampler:
|
||||
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
|
||||
|
||||
|
||||
class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
class ReproduceBatchSampler(ReproducibleBatchSampler):
|
||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
|
||||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
|
||||
"""
|
||||
可以使得 batch_sampler 对象状态恢复的 wrapper 。
|
||||
|
||||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一边该对象,然后将迭代
|
||||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproduceBatchSampler 将首先遍历一边该对象,然后将迭代
|
||||
出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。
|
||||
:param batch_size: 每个 batch 的大小是多少。
|
||||
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。
|
||||
@ -143,7 +144,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
self.need_reinitialize = False
|
||||
|
||||
def set_distributed(self, num_replicas, rank, pad=True):
|
||||
raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.")
|
||||
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
|
||||
@ -158,6 +159,211 @@ class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
|
||||
|
||||
|
||||
class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
|
||||
drop_last: bool = False, seed: int = 0, **kwargs):
|
||||
"""
|
||||
随机分 batch 的 batch_sampler 。
|
||||
|
||||
:param dataset: 实现了 __len__ 方法的数据容器。
|
||||
:param batch_size: 每个 batch 的大小
|
||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。
|
||||
:param drop_last: 如果最后一个 batch 的 sample 数量无法凑齐 batch_size 这么多,是否需要丢掉。
|
||||
:param seed: 设置的随机数种子
|
||||
:param kwargs: fastNLP 保留使用
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.dataset = dataset
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.drop_last = drop_last
|
||||
self.seed = seed
|
||||
|
||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量
|
||||
|
||||
# 多卡的相关的参数
|
||||
self.num_replicas = kwargs.get("num_replicas", 1)
|
||||
self.rank = kwargs.get("rank", 0)
|
||||
self.epoch = kwargs.get("epoch", -1)
|
||||
self.pad = kwargs.get("pad", False) # 该参数在单卡上不具有任何意义;
|
||||
|
||||
# 是否处于iteration之间,为True不允许调用 set_distributed()和load_state_dict()
|
||||
self.during_iter = kwargs.get("during_iter", False)
|
||||
|
||||
# 以下变量为内部使用恢复状态的变量。
|
||||
self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)
|
||||
|
||||
def set_distributed(self, num_replicas, rank, pad=True):
|
||||
assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
|
||||
"during an unfinished iteration."
|
||||
assert num_replicas > 0 and isinstance(num_replicas, int)
|
||||
assert isinstance(rank, int) and 0 <= rank < num_replicas
|
||||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态;
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.pad = pad
|
||||
|
||||
return self
|
||||
|
||||
def __iter__(self):
|
||||
if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了
|
||||
self.num_consumed_samples = 0
|
||||
self.during_iter = True
|
||||
|
||||
indices = list(range(len(self.dataset)))
|
||||
|
||||
if self.shuffle:
|
||||
if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的
|
||||
_batches = []
|
||||
for _i in range(self.old_num_replicas):
|
||||
_indices = indices[_i:len(indices):self.old_num_replicas]
|
||||
__batches = self.batchify(_indices, self.old_batch_size, seed=self.seed + self.epoch)
|
||||
_batches.append(__batches)
|
||||
batches = list(chain(*[_ for _ in zip(*_batches)]))
|
||||
indices = list(chain(*batches))
|
||||
indices = indices[self.num_consumed_samples:]
|
||||
# 取出这个 rank ,
|
||||
indices = indices[self.rank:len(indices):self.num_replicas]
|
||||
batches = self.batchify(indices, self.batch_size, seed=self.seed + self.epoch)
|
||||
batches = list(map(list, batches))
|
||||
else:
|
||||
indices = indices[self.num_consumed_samples:]
|
||||
indices = indices[self.rank:len(indices):self.num_replicas]
|
||||
_num_batches = len(indices) // self.batch_size
|
||||
if _num_batches == 0:
|
||||
batches = [indices]
|
||||
else:
|
||||
batches = list(map(list, np.array_split(indices[:_num_batches*self.batch_size], _num_batches)))
|
||||
if len(indices)%self.batch_size!=0:
|
||||
batches.append(indices[_num_batches*self.batch_size:])
|
||||
|
||||
need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
|
||||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
|
||||
if len(batches) > 0:
|
||||
if len(batches[-1])<self.batch_size:
|
||||
batches[-1].append(batches[-1][0]) # 这里可以保证这个bucket的长度没被破坏。
|
||||
else:
|
||||
batches.append([batches[-1][0]])
|
||||
elif self.pad is False and need_pad_num !=0 and need_pad_num>self.rank:
|
||||
if len(batches):
|
||||
batches[-1].pop(-1)
|
||||
if len(batches[-1])==0:
|
||||
batches.pop(-1)
|
||||
|
||||
assert sum(map(len, batches)) == self.num_left_samples
|
||||
|
||||
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
|
||||
batches = batches[:-1]
|
||||
|
||||
for batch in batches:
|
||||
self.num_consumed_samples += self.num_replicas * len(batch)
|
||||
yield list(map(int, batch))
|
||||
self.during_iter = False
|
||||
self.num_consumed_samples = 0
|
||||
self.old_batch_size = self.batch_size
|
||||
self.old_num_replicas = self.num_replicas
|
||||
if self.epoch < 0: # 防止用户没有修改epoch,导致每个epoch都一样了
|
||||
self.epoch -= 1
|
||||
|
||||
def batchify(self, indices, batch_size, seed):
|
||||
"""
|
||||
将 indices 分为 batches
|
||||
|
||||
:param sorted_indices: List[int]
|
||||
:param batch_size: int
|
||||
:param seed: int
|
||||
:return: List[List[int]]
|
||||
"""
|
||||
# 实际的 bucket 大小
|
||||
rng = np.random.default_rng(abs(seed))
|
||||
rng.shuffle(indices)
|
||||
num_samples = 0
|
||||
batches = []
|
||||
while num_samples<len(indices):
|
||||
batches.append(indices[num_samples:num_samples+batch_size])
|
||||
num_samples += batch_size
|
||||
return batches
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
||||
@property
|
||||
def batch_idx_in_epoch(self):
|
||||
if self.drop_last:
|
||||
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
|
||||
else:
|
||||
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
|
||||
(self.num_left_samples + self.batch_size - 1) // self.batch_size
|
||||
|
||||
@property
|
||||
def total_size(self):
|
||||
"""
|
||||
这个变量代表的含义是当前这个sampler会最终产生出的index数量(包括了其它rank的),因为replica和pad的原因,这个值可能等于、
|
||||
大于或者小于len(dataset)
|
||||
|
||||
:return:
|
||||
"""
|
||||
return self.num_consumed_samples + self.num_replicas*self.num_left_samples
|
||||
|
||||
@property
|
||||
def num_left_samples(self):
|
||||
"""
|
||||
返回当前 iteration 还有多少个 sample 结束,表示的是当前 rank 的还剩多少。
|
||||
|
||||
:return:
|
||||
"""
|
||||
num_consumed_samples = self.num_consumed_samples
|
||||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
|
||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
|
||||
|
||||
def __len__(self)->int:
|
||||
"""
|
||||
返回当前 sampler 还会返回多少个 batch 的数据
|
||||
|
||||
:return:
|
||||
"""
|
||||
num_sampler_per_rank = self.total_size//self.num_replicas
|
||||
num_batches = num_sampler_per_rank//self.batch_size if self.drop_last else \
|
||||
(num_sampler_per_rank+self.batch_size-1)//self.batch_size
|
||||
return num_batches
|
||||
|
||||
def state_dict(self) -> Dict:
|
||||
if self.old_batch_size != self.batch_size:
|
||||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
|
||||
" consumed. ")
|
||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
|
||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
|
||||
'batch_size': self.batch_size,
|
||||
'num_replicas': self.num_replicas}
|
||||
|
||||
return states
|
||||
|
||||
def load_state_dict(self, states: Dict):
|
||||
# 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0;
|
||||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
|
||||
"during an unfinished iteration."
|
||||
|
||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
|
||||
f"we cannot use {self.__class__.__name__} to load it."
|
||||
|
||||
length = states['length']
|
||||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
|
||||
"and current dataset."
|
||||
self.seed = states['seed']
|
||||
self.epoch = states['epoch']
|
||||
self.num_consumed_samples = states['num_consumed_samples']
|
||||
if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0
|
||||
self.num_consumed_samples = 0
|
||||
if self.shuffle != states['shuffle']:
|
||||
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
|
||||
f"we use shuffle={states['shuffle']}")
|
||||
self.shuffle = states["shuffle"]
|
||||
self.old_batch_size = states['batch_size']
|
||||
self.old_num_replicas = states['num_replicas']
|
||||
|
||||
|
||||
class BucketedBatchSampler(ReproducibleBatchSampler):
|
||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
|
||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
|
||||
|
@ -54,13 +54,12 @@ class RandomSampler(ReproducibleSampler):
|
||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
|
||||
"""
|
||||
|
||||
|
||||
:param dataset: 实现了 __len__ 方法的数据容器
|
||||
:param shuffle: 是否在每次 iterate 的时候打乱顺序。
|
||||
:param seed: 随机数种子。
|
||||
:param kwargs: 用户不需要使用,fastNLP 内部使用
|
||||
"""
|
||||
|
||||
super(RandomSampler, self).__init__()
|
||||
self.dataset = dataset
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
|
@ -21,7 +21,6 @@ __all__ = [
|
||||
'nullcontext',
|
||||
'pretty_table_printer',
|
||||
'Option',
|
||||
'indice_collate_wrapper',
|
||||
'deprecated',
|
||||
'seq_len_to_mask',
|
||||
'rank_zero_rm',
|
||||
@ -37,6 +36,7 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device
|
||||
from .torch_utils import torch_move_data_to_device
|
||||
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
|
||||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
|
||||
indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
|
||||
deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
|
||||
from ..dataloaders.utils import indice_collate_wrapper
|
||||
|
||||
|
||||
|
@ -6,7 +6,7 @@ import warnings
|
||||
from dataclasses import is_dataclass
|
||||
from copy import deepcopy
|
||||
from collections import defaultdict, OrderedDict
|
||||
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional
|
||||
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence
|
||||
from typing import Tuple, Optional
|
||||
from time import sleep
|
||||
|
||||
@ -35,7 +35,6 @@ __all__ = [
|
||||
'nullcontext',
|
||||
'pretty_table_printer',
|
||||
'Option',
|
||||
'indice_collate_wrapper',
|
||||
'deprecated',
|
||||
'seq_len_to_mask',
|
||||
'rank_zero_rm',
|
||||
@ -513,24 +512,6 @@ class Option(dict):
|
||||
self.update(state)
|
||||
|
||||
|
||||
def indice_collate_wrapper(func):
|
||||
"""
|
||||
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。
|
||||
|
||||
:param func: 需要修饰的函数
|
||||
:return:
|
||||
"""
|
||||
|
||||
def wrapper(tuple_data):
|
||||
indice, ins_list = [], []
|
||||
for idx, ins in tuple_data:
|
||||
indice.append(idx)
|
||||
ins_list.append(ins)
|
||||
return indice, func(ins_list)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
_emitted_deprecation_warnings = set()
|
||||
|
||||
|
||||
|
@ -2,7 +2,7 @@ import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
|
||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
|
||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
@ -278,7 +278,7 @@ class TestPaddleDriverFunctions:
|
||||
dataset = PaddleNormalDataset()
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_sampler=RandomBatchSampler(
|
||||
batch_sampler=ReproduceBatchSampler(
|
||||
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
|
||||
batch_size,
|
||||
drop_last,
|
||||
@ -287,7 +287,7 @@ class TestPaddleDriverFunctions:
|
||||
res = PaddleSingleDriver.get_dataloader_args(dataloader)
|
||||
|
||||
assert isinstance(res.dataset, PaddleNormalDataset)
|
||||
assert isinstance(res.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(res.batch_sampler, ReproduceBatchSampler)
|
||||
if shuffle:
|
||||
assert isinstance(res.sampler, paddle.io.RandomSampler)
|
||||
else:
|
||||
@ -387,7 +387,7 @@ class TestSetDistReproDataloader:
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现
|
||||
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True),
|
||||
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler
|
||||
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler
|
||||
"""
|
||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
|
||||
@ -400,7 +400,7 @@ class TestSetDistReproDataloader:
|
||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
|
||||
else:
|
||||
# 此时会替换 batch_sampler
|
||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
|
||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
|
||||
assert replaced_loader.drop_last == dataloader.drop_last
|
||||
@ -414,11 +414,11 @@ class TestSetDistReproDataloader:
|
||||
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler
|
||||
"""
|
||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
|
||||
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
|
||||
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
|
||||
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
|
||||
assert replaced_loader.batch_sampler is dist
|
||||
|
||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
|
||||
@ -450,7 +450,7 @@ class TestSetDistReproDataloader:
|
||||
"""
|
||||
dataloader = DataLoader(
|
||||
dataset=self.dataset,
|
||||
batch_sampler=RandomBatchSampler(
|
||||
batch_sampler=ReproduceBatchSampler(
|
||||
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
|
||||
batch_size=4,
|
||||
drop_last=False,
|
||||
@ -459,7 +459,7 @@ class TestSetDistReproDataloader:
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
|
||||
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
|
||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
|
||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
|
||||
assert replaced_loader.drop_last == dataloader.drop_last
|
||||
@ -500,20 +500,20 @@ class TestSetDistReproDataloader:
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_idx.update(batch)
|
||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
|
||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
|
||||
sampler_states = replaced_loader.batch_sampler.state_dict()
|
||||
else:
|
||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
|
||||
|
||||
# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range
|
||||
left_idxes = set()
|
||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
|
||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
|
||||
batch_size = replaced_loader.batch_sampler.batch_size
|
||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
|
||||
# 重新改造 dataloader
|
||||
new_loader = DataLoader(
|
||||
dataset=replaced_loader.dataset,
|
||||
batch_sampler=RandomBatchSampler(
|
||||
batch_sampler=ReproduceBatchSampler(
|
||||
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
|
||||
batch_size=batch_size,
|
||||
drop_last=False,
|
||||
@ -603,7 +603,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
dataset = PaddleRandomMaxDataset(40, 10)
|
||||
dataloader = DataLoader(
|
||||
dataset=dataset,
|
||||
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
|
||||
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
|
||||
)
|
||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")
|
||||
|
||||
@ -627,7 +627,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
# 更改 batch_size
|
||||
dataloader = DataLoader(
|
||||
dataset=dataset,
|
||||
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
|
||||
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
|
||||
)
|
||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
|
||||
replaced_loader = load_states.pop("dataloader")
|
||||
@ -637,7 +637,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
# 2. 检查 batch_sampler 是否被正确地加载和替换
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler
|
||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
|
||||
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
|
||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4
|
||||
|
||||
|
@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.utils import (
|
||||
replace_batch_sampler,
|
||||
replace_sampler,
|
||||
)
|
||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
|
||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
|
||||
if _NEED_IMPORT_PADDLE:
|
||||
import paddle
|
||||
@ -36,12 +36,12 @@ def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices,
|
||||
def test_replace_batch_sampler():
|
||||
dataset = PaddleNormalDataset(10)
|
||||
dataloader = DataLoader(dataset, batch_size=32)
|
||||
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
|
||||
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
|
||||
|
||||
replaced_loader = replace_batch_sampler(dataloader, batch_sampler)
|
||||
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
|
||||
assert isinstance(replaced_loader.dataset, PaddleNormalDataset)
|
||||
assert len(replaced_loader.dataset) == len(dataset)
|
||||
assert replaced_loader.batch_sampler.batch_size == 16
|
||||
|
@ -2,7 +2,7 @@ import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver
|
||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
|
||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
|
||||
@ -17,7 +17,7 @@ if _NEED_IMPORT_PADDLE:
|
||||
|
||||
def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
|
||||
"""
|
||||
建立一个 batch_sampler 为 RandomBatchSampler 的 dataloader
|
||||
建立一个 batch_sampler 为 ReproduceBatchSampler 的 dataloader
|
||||
"""
|
||||
if shuffle:
|
||||
sampler = torch.utils.data.RandomSampler(dataset)
|
||||
@ -25,7 +25,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
|
||||
sampler = torch.utils.data.SequentialSampler(dataset)
|
||||
dataloader = DataLoader(
|
||||
dataset=dataset,
|
||||
batch_sampler=RandomBatchSampler(
|
||||
batch_sampler=ReproduceBatchSampler(
|
||||
BatchSampler(
|
||||
sampler, batch_size=batch_size, drop_last=drop_last
|
||||
),
|
||||
@ -306,7 +306,7 @@ class TestTorchDriverFunctions:
|
||||
res = TorchSingleDriver.get_dataloader_args(dataloader)
|
||||
|
||||
assert isinstance(res.dataset, TorchNormalDataset)
|
||||
assert isinstance(res.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(res.batch_sampler, ReproduceBatchSampler)
|
||||
if shuffle:
|
||||
assert isinstance(res.sampler, torch.utils.data.RandomSampler)
|
||||
else:
|
||||
@ -401,7 +401,7 @@ class TestSetDistReproDataloader:
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现
|
||||
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True),
|
||||
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler
|
||||
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler
|
||||
"""
|
||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
|
||||
@ -414,7 +414,7 @@ class TestSetDistReproDataloader:
|
||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
|
||||
else:
|
||||
# 此时会替换 batch_sampler
|
||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
|
||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
|
||||
assert replaced_loader.drop_last == dataloader.drop_last
|
||||
@ -428,11 +428,11 @@ class TestSetDistReproDataloader:
|
||||
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler
|
||||
"""
|
||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
|
||||
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
|
||||
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
|
||||
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
|
||||
assert replaced_loader.batch_sampler is dist
|
||||
|
||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
|
||||
@ -466,7 +466,7 @@ class TestSetDistReproDataloader:
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
|
||||
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
|
||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
|
||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
|
||||
assert replaced_loader.drop_last == dataloader.drop_last
|
||||
@ -502,14 +502,14 @@ class TestSetDistReproDataloader:
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_idx.update(batch)
|
||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
|
||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
|
||||
sampler_states = replaced_loader.batch_sampler.state_dict()
|
||||
else:
|
||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()
|
||||
|
||||
# 重新加载,应该可以输出剩下的内容,且对于 TorchNormalDataset 来说,排序后应该是一个 range
|
||||
left_idxes = set()
|
||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
|
||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
|
||||
batch_size = replaced_loader.batch_sampler.batch_size
|
||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
|
||||
# 重新改造 dataloader
|
||||
@ -613,7 +613,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
# 2. 检查 batch_sampler 是否被正确地加载和替换
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler
|
||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
|
||||
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
|
||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4
|
||||
|
||||
|
@ -30,7 +30,7 @@ class SequenceDataSet:
|
||||
|
||||
|
||||
def check_replace_sampler(driver):
|
||||
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler
|
||||
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproduceBatchSampler
|
||||
# reproducible 是 True 和 False
|
||||
|
||||
# 需要 check 返回的 sampler 和 dataloader 都不同了
|
||||
|
@ -4,7 +4,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
|
||||
replace_batch_sampler,
|
||||
replace_sampler,
|
||||
)
|
||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
|
||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
|
||||
from torch.utils.data import DataLoader, BatchSampler
|
||||
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
@ -14,12 +14,12 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
def test_replace_batch_sampler():
|
||||
dataset = TorchNormalDataset(10)
|
||||
dataloader = DataLoader(dataset, batch_size=32)
|
||||
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
|
||||
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
|
||||
|
||||
replaced_loader = replace_batch_sampler(dataloader, batch_sampler)
|
||||
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
|
||||
assert isinstance(replaced_loader.dataset, TorchNormalDataset)
|
||||
assert len(replaced_loader.dataset) == len(dataset)
|
||||
assert replaced_loader.batch_sampler.batch_size == 16
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from itertools import chain
|
||||
from copy import deepcopy
|
||||
|
||||
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
|
||||
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
|
||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
|
||||
@ -19,7 +19,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
# before_batch_size = 7
|
||||
# dataset = TorchNormalDataset(num_of_data=100)
|
||||
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
|
||||
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||
#
|
||||
# forward_steps = 3
|
||||
@ -29,15 +29,15 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
#
|
||||
# # 1. 保存状态
|
||||
# _get_re_batchsampler = dataloader.batch_sampler
|
||||
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
|
||||
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
|
||||
# state = _get_re_batchsampler.state_dict()
|
||||
# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
|
||||
# "sampler_type": "RandomBatchSampler"}
|
||||
# "sampler_type": "ReproduceBatchSampler"}
|
||||
#
|
||||
# # 2. 断点重训,重新生成一个 dataloader;
|
||||
# # 不改变 batch_size;
|
||||
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
|
||||
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
# re_batchsampler.load_state_dict(state)
|
||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||
#
|
||||
@ -54,7 +54,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
# # 改变 batch_size;
|
||||
# after_batch_size = 3
|
||||
# dataloader = DataLoader(dataset, batch_size=after_batch_size)
|
||||
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
# re_batchsampler.load_state_dict(state)
|
||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||
#
|
||||
@ -100,7 +100,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
# dataset = TorchNormalDataset(num_of_data=100)
|
||||
# # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的;
|
||||
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
|
||||
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||
#
|
||||
# # 将一轮的所有数据保存下来,看是否恢复的是正确的;
|
||||
@ -112,13 +112,13 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
#
|
||||
# # 1. 保存状态
|
||||
# _get_re_batchsampler = dataloader.batch_sampler
|
||||
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
|
||||
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
|
||||
# state = _get_re_batchsampler.state_dict()
|
||||
#
|
||||
# # 2. 断点重训,重新生成一个 dataloader;
|
||||
# # 不改变 batch_size;
|
||||
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
|
||||
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
# re_batchsampler.load_state_dict(state)
|
||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||
#
|
||||
@ -511,3 +511,313 @@ class TestBucketedBatchSampler:
|
||||
already_seen_set.update(batch)
|
||||
|
||||
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)
|
||||
|
||||
|
||||
class TestRandomBatchSampler:
|
||||
@pytest.mark.parametrize('shuffle', [True, False])
|
||||
@pytest.mark.parametrize('drop_last', [True, False])
|
||||
@pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71])
|
||||
def test_single_num_batch(self, shuffle, drop_last, num):
|
||||
# 数量不够不报错
|
||||
for num in [2, 7, 14, 15, 70, 71]:
|
||||
dataset = DatasetWithVaryLength(num_of_data=num)
|
||||
before_batch_size = 7
|
||||
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
|
||||
drop_last=drop_last,
|
||||
shuffle=shuffle)
|
||||
count = len(list(iter(re_batchsampler)))
|
||||
if drop_last:
|
||||
assert count==num//before_batch_size, num
|
||||
else:
|
||||
assert count==(num+before_batch_size-1)//before_batch_size, num
|
||||
|
||||
@pytest.mark.parametrize('shuffle', [True, False])
|
||||
@pytest.mark.parametrize('drop_last', [True, False])
|
||||
def test_single(self, shuffle, drop_last):
|
||||
|
||||
before_batch_size = 7
|
||||
num_batch_per_bucket = 4 # 那么任意 batch 内的长度差值不应该超过4
|
||||
|
||||
dataset = DatasetWithVaryLength(num_of_data=1000)
|
||||
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
|
||||
drop_last=drop_last,
|
||||
shuffle=shuffle)
|
||||
re_batchsampler.set_epoch(0)
|
||||
forward_steps = 10
|
||||
iterator = iter(re_batchsampler)
|
||||
already_generate_indices = set()
|
||||
for _ in range(forward_steps):
|
||||
batch = next(iterator)
|
||||
already_generate_indices.update(batch)
|
||||
|
||||
# 1. 保存状态
|
||||
state = re_batchsampler.state_dict()
|
||||
|
||||
# 2. 断点重训,继续训练
|
||||
re_batchsampler2 = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
|
||||
drop_last=drop_last,
|
||||
shuffle=shuffle)
|
||||
re_batchsampler2.load_state_dict(state)
|
||||
re_batchsampler2.set_epoch(0)
|
||||
new_already_generate_indices = set()
|
||||
mask = np.ones(len(dataset), dtype=bool)
|
||||
mask[list(already_generate_indices)] = 0
|
||||
indices = np.arange(len(dataset))[mask]
|
||||
max_diff = -1
|
||||
for i in range(len(indices)-before_batch_size * num_batch_per_bucket):
|
||||
max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i])
|
||||
for batch in re_batchsampler2:
|
||||
for b in batch:
|
||||
assert b not in already_generate_indices
|
||||
new_already_generate_indices.update(batch)
|
||||
if drop_last is False:
|
||||
assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset)
|
||||
|
||||
# 改变 batch_size;
|
||||
after_batch_size = 3
|
||||
re_batchsampler3 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
|
||||
drop_last=drop_last,
|
||||
shuffle=shuffle)
|
||||
re_batchsampler3.load_state_dict(state)
|
||||
re_batchsampler3.set_epoch(0)
|
||||
count = 0
|
||||
|
||||
mask = np.ones(len(dataset), dtype=bool)
|
||||
mask[list(already_generate_indices)] = 0
|
||||
indices = np.arange(len(dataset))[mask]
|
||||
|
||||
for batch in re_batchsampler3:
|
||||
for b in batch:
|
||||
assert b not in already_generate_indices
|
||||
already_generate_indices.update(batch)
|
||||
count += 1
|
||||
if count > 5:
|
||||
break
|
||||
|
||||
# 再 save ,不允许再上个epoch没结束继续sample
|
||||
after_batch_size = 5
|
||||
with pytest.raises(RuntimeError):
|
||||
state = re_batchsampler3.state_dict()
|
||||
|
||||
for batch in re_batchsampler3: # consume all, 这样才能save
|
||||
pass
|
||||
|
||||
already_generate_indices = set()
|
||||
count = 0
|
||||
for batch in re_batchsampler3: # 重新开始
|
||||
for b in batch:
|
||||
assert b not in already_generate_indices
|
||||
already_generate_indices.update(batch)
|
||||
count += 1
|
||||
if count > 5:
|
||||
break
|
||||
|
||||
state = re_batchsampler3.state_dict()
|
||||
# 这里的 drop_last 为 False,需要最终是所有 sample
|
||||
re_batchsampler4 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
|
||||
drop_last=False,
|
||||
shuffle=shuffle)
|
||||
re_batchsampler4.load_state_dict(state)
|
||||
re_batchsampler4.set_epoch(0)
|
||||
|
||||
mask = np.ones(len(dataset), dtype=bool)
|
||||
mask[list(already_generate_indices)] = 0
|
||||
for batch in re_batchsampler4:
|
||||
for b in batch:
|
||||
assert b not in already_generate_indices
|
||||
already_generate_indices.update(batch)
|
||||
|
||||
assert len(already_generate_indices) == len(dataset)
|
||||
|
||||
@pytest.mark.parametrize('shuffle', [True, False])
|
||||
@pytest.mark.parametrize('drop_last', [True, False])
|
||||
@pytest.mark.parametrize('pad', [True, False])
|
||||
def test_multi(self, shuffle, drop_last, pad):
|
||||
# def test_multi(self, shuffle=True, drop_last=False, pad=False):
|
||||
|
||||
# no shuffle
|
||||
num_replica = 2
|
||||
dataset = DatasetWithVaryLength(num_of_data=1000)
|
||||
batch_size = 5
|
||||
num_batch_per_bucket = 10
|
||||
lengths = []
|
||||
rank0_already_seen_indexes = None
|
||||
max_diff = num_batch_per_bucket * batch_size * num_replica
|
||||
for rank in range(num_replica):
|
||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
|
||||
shuffle = shuffle, drop_last=drop_last)
|
||||
sampler.set_epoch(0)
|
||||
sampler.set_distributed(num_replica, rank=rank, pad=pad)
|
||||
lengths.append(len(sampler))
|
||||
already_seen_indexes = set()
|
||||
repeat_count = 0
|
||||
for batch in sampler:
|
||||
for b in batch:
|
||||
repeat_count += int(b in already_seen_indexes)
|
||||
if rank0_already_seen_indexes: # 不能交叉出现
|
||||
assert b not in rank0_already_seen_indexes
|
||||
already_seen_indexes.update(batch)
|
||||
if rank0_already_seen_indexes is None:
|
||||
rank0_already_seen_indexes = already_seen_indexes
|
||||
if pad: # 应该允许重复一次
|
||||
assert repeat_count<=1
|
||||
else:
|
||||
assert repeat_count==0
|
||||
|
||||
assert len(set(lengths))==1, lengths # 每个进程的batch数量一致
|
||||
|
||||
# 多进程的保存
|
||||
already_seen_indexes = set()
|
||||
for rank in range(num_replica):
|
||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
|
||||
shuffle = shuffle, drop_last=drop_last)
|
||||
sampler.set_epoch(0)
|
||||
sampler.set_distributed(num_replica, rank=rank, pad=pad)
|
||||
lengths.append(len(sampler))
|
||||
count = 0
|
||||
for batch in sampler:
|
||||
already_seen_indexes.update(batch)
|
||||
if count>5:
|
||||
break
|
||||
count += 1
|
||||
state = sampler.state_dict()
|
||||
|
||||
# 切换成单机
|
||||
new_batch_size = 6
|
||||
num_batch_per_bucket = 3
|
||||
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
|
||||
shuffle=shuffle, drop_last=drop_last)
|
||||
new_sampler.load_state_dict(state)
|
||||
repeat_count = 0
|
||||
new_already_seen_indexes = set(list(already_seen_indexes))
|
||||
|
||||
mask = np.ones(len(dataset), dtype=bool)
|
||||
mask[list(already_seen_indexes)] = 0
|
||||
indices = np.arange(len(dataset))[mask]
|
||||
|
||||
for batch in new_sampler:
|
||||
for b in batch:
|
||||
repeat_count += int(b in new_already_seen_indexes)
|
||||
new_already_seen_indexes.update(batch)
|
||||
if pad: # 应该允许重复一次
|
||||
assert repeat_count <= 1
|
||||
else:
|
||||
assert repeat_count == 0
|
||||
if drop_last is False: # 如果没有drop应该相等
|
||||
assert len(new_already_seen_indexes)==len(dataset)
|
||||
|
||||
# 测试替换卡的数量。
|
||||
num_replica = 3
|
||||
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
|
||||
shuffle=shuffle, drop_last=drop_last)
|
||||
new_sampler.set_epoch(0)
|
||||
new_sampler.load_state_dict(state)
|
||||
new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad)
|
||||
repeat_count = 0
|
||||
|
||||
mask = np.ones(len(dataset), dtype=bool)
|
||||
mask[list(already_seen_indexes)] = 0
|
||||
indices = np.arange(len(dataset))[mask]
|
||||
|
||||
for batch in new_sampler:
|
||||
for b in batch:
|
||||
repeat_count += int(b in already_seen_indexes)
|
||||
if pad: # 应该允许重复一次
|
||||
assert repeat_count <= 1
|
||||
else:
|
||||
assert repeat_count == 0
|
||||
|
||||
@pytest.mark.parametrize('shuffle', [True, False])
|
||||
@pytest.mark.parametrize('drop_last', [True, False])
|
||||
@pytest.mark.parametrize('pad', [True, False])
|
||||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
|
||||
@pytest.mark.parametrize('num_replicas', [2, 3])
|
||||
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas):
|
||||
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2):
|
||||
dataset = DatasetWithVaryLength(num_of_data=num_samples)
|
||||
batch_size = 6
|
||||
if num_replicas*batch_size > num_samples:
|
||||
return
|
||||
num_batch_per_bucket = 10
|
||||
samplers = []
|
||||
lengths = []
|
||||
for i in range(num_replicas):
|
||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
|
||||
shuffle=shuffle, drop_last=drop_last)
|
||||
sampler.set_distributed(num_replicas, rank=i, pad=pad)
|
||||
sampler.set_epoch(0)
|
||||
samplers.append(sampler)
|
||||
lengths.append(len(list(iter(sampler))))
|
||||
assert len(set(lengths))==1
|
||||
|
||||
@pytest.mark.parametrize('shuffle', [True, False])
|
||||
@pytest.mark.parametrize('drop_last', [True, False])
|
||||
@pytest.mark.parametrize('pad', [True, False])
|
||||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
|
||||
@pytest.mark.parametrize('num_replicas', [1, 2, 3])
|
||||
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
|
||||
"""
|
||||
测试是否能够正确地恢复使用过的(forward)数据
|
||||
|
||||
:return:
|
||||
"""
|
||||
batch_size = 6
|
||||
dataset = DatasetWithVaryLength(num_of_data=num_samples)
|
||||
samplers = []
|
||||
num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas))
|
||||
for i in range(num_replicas):
|
||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
|
||||
shuffle=shuffle, drop_last=drop_last)
|
||||
|
||||
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
|
||||
samplers.append(sampler)
|
||||
count = 0
|
||||
already_seen_sets = [set()]
|
||||
already_seen_set = set()
|
||||
for batchs in zip(*samplers):
|
||||
batch = chain(*batchs)
|
||||
already_seen_set.update(batch)
|
||||
already_seen_sets.append(deepcopy(already_seen_set))
|
||||
count += 1
|
||||
if count > 3:
|
||||
break
|
||||
states = samplers[0].state_dict()
|
||||
for i in range(len(already_seen_sets)):
|
||||
states['num_consumed_samples'] = num_consumed_samples_array[i]
|
||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
|
||||
shuffle=shuffle, drop_last=drop_last)
|
||||
sampler.set_epoch(0)
|
||||
already_seen_set = deepcopy(already_seen_sets[i])
|
||||
for batch in sampler:
|
||||
already_seen_set.update(batch)
|
||||
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(
|
||||
dataset)
|
||||
|
||||
# 测试保存之后再次保存
|
||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last)
|
||||
sampler.set_epoch(0)
|
||||
states['num_consumed_samples'] = num_consumed_samples_array[2]
|
||||
if len(already_seen_sets)<3:
|
||||
return
|
||||
already_seen_set = already_seen_sets[2]
|
||||
count = 0
|
||||
for batch in sampler:
|
||||
already_seen_set.update(batch)
|
||||
count += 1
|
||||
if count > 6:
|
||||
break
|
||||
|
||||
states = sampler.state_dict()
|
||||
num_consumed_samples_array = list(range(len(dataset)))
|
||||
states['num_consumed_samples'] = num_consumed_samples_array[count]
|
||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last)
|
||||
sampler.load_state_dict(states)
|
||||
sampler.set_epoch(0)
|
||||
for batch in sampler:
|
||||
already_seen_set.update(batch)
|
||||
|
||||
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)
|
||||
|
Loading…
Reference in New Issue
Block a user