From 9bb1ed4ccf4e41107249e609a0ef11808171c52c Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Fri, 20 May 2022 18:04:00 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=20mixdataloader=20=E6=96=87?= =?UTF-8?q?=E6=A1=A3=EF=BC=8C=20=E4=BF=AE=E6=94=B9mix=5Fsampler,=20mixdata?= =?UTF-8?q?loader=E4=BB=A3=E7=A0=81=EF=BC=8C=20=E5=A2=9E=E5=8A=A0=E7=9B=B8?= =?UTF-8?q?=E5=BA=94=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/dataloaders/jittor_dataloader/fdl.py | 4 +- fastNLP/core/dataloaders/mix_dataloader.py | 169 +++--- fastNLP/core/samplers/mix_sampler.py | 343 ++++-------- tests/core/dataloaders/test_mixdataloader.py | 495 ++++++++++++++++++ 4 files changed, 700 insertions(+), 311 deletions(-) create mode 100644 tests/core/dataloaders/test_mixdataloader.py diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index da896cf8..349fb444 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -24,7 +24,7 @@ from fastNLP.core.dataset import DataSet as FDataSet class _JittorDataset(Dataset): """ 对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset - + """ def __init__(self, dataset) -> None: @@ -37,7 +37,7 @@ class _JittorDataset(Dataset): item = item.tolist() return (item, self.dataset[item]) - + class JittorDataLoader: """ 提供给 ``jittor`` 框架使用的 ``DataLoader`` 函数,``JittorDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad, diff --git a/fastNLP/core/dataloaders/mix_dataloader.py b/fastNLP/core/dataloaders/mix_dataloader.py index d6f6a9be..29b0cd0b 100644 --- a/fastNLP/core/dataloaders/mix_dataloader.py +++ b/fastNLP/core/dataloaders/mix_dataloader.py @@ -2,13 +2,14 @@ __all__ = [ 'MixDataLoader' ] -from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence +from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping import numpy as np from fastNLP.core.dataset import DataSet, Instance from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.core.collators import Collator if _NEED_IMPORT_TORCH: from torch.utils.data import DataLoader, Sampler @@ -18,12 +19,13 @@ else: class _MixDataset: """ - 将所有数据集当成一个混合大数据集来对待,实现的__getitem__能区别每个数据idx + 将所有数据集当成一个混合大数据集来对待, 在 __getitem__() 能根据输入的 idx 来判断属于哪个小数据并返回其 ds_index """ + def __init__(self, datasets: list = None) -> None: """ - :param datasets: 数据集的列表 + :param datasets: 实现了 __getitem__() 和 __len__() 的对象的序列 """ self.datasets = datasets # 记录每个数据集的长度索引, 以便根据idx定位数据集的位置 @@ -35,7 +37,7 @@ class _MixDataset: def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]: """ - 根据index索引获取数据 + 根据index索引获取数据, 能够跟 idx 的范围定位属于哪个小数据并返回 :param idx: 整数类型的index或者列表 :return: @@ -69,8 +71,9 @@ class _MixCollateFn: 存在多个auto_collate和多个collate_fn时候,对一个批次数据集应用哪个auto_collate和collate_fn的问题 """ - def __init__(self, collate_fns: Optional[Union[List[Callable], Callable]] = None, - auto_collators: Optional[List[Callable]] = None) -> None: + + def __init__(self, collate_fns: Union[List[Callable], Callable]) -> None: + if isinstance(collate_fns, Sequence): self.collate_fns = lambda idx, lst: collate_fns[idx](lst) elif callable(collate_fns): @@ -78,96 +81,124 @@ class _MixCollateFn: else: self.collate_fns = lambda idx, lst: lst - self.collate_fns = collate_fns - self.auto_collators = auto_collators - def __call__(self, ins_list: List) -> Dict: """ 调用一次该方法,我们将ins_list视为同一个数据集采样出来的,故ds_index只能为一种 + :param ins_list: :return: """ _ins_list, _ds_index = [], 0 for ins, _ds_index in ins_list: _ins_list.append(ins) - # auto_collate先处理 - if self.auto_collators is not None: - _ins_list = self.auto_collators[_ds_index](_ins_list) _ins_list = self.collate_fns(_ds_index, _ins_list) return _ins_list class MixDataLoader(DataLoader): """ - 针对一下三种情况提供的MixDataLoader: - 1. 给定datasets集合或者列表,顺序采样datasets,处理采样完首个dataset后取出第二个dataset,重复上面过程直至datasets取完。 - 2. 给定datasets集合或者列表,随机采样这个datasets的任意一个数据集组合成一个混合的batch返回给用户,直至datasets所有数据集采样完。 - 3. 给定datasets集合或者列表,轮流采样datasets:即是循环遍历datasets,每取出一个dataset采样一个batch的数据,然后取出下一个dataset - 采样一个batch数据,重复上述过程直至某个dataset采样结束或者所有dataset采样结束。 + 针对一下四种情况提供的 ``MixDataLoader``, 目前只支持 ``torch`` 框架的版本, 其中 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: + + * 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 + 接一个的 sample 完所有数据。 + * 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample + 混合数据集 datasets 的数据组成一个 batch 序列返回。 + * 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, + 再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 + * 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; + 且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0 None: + ds_ratio: Union[None, str, Dict[str, float]] = None, + pin_memory: bool = False) -> None: """ - :param datasets: dataset的列表 - :param mode: mode包括四种类型,前三种分别为"sequential", "mix", "polling"分别代表上述三种情况, - 当mode为Sampler时为用户定制,此时sampler,ds_ratio,batch_size,drop_last失效,此时Sampler应该是一个可迭代 - 对象,每次迭代返回的是List[int] - :param collate_fn: 对取得到的数据进行打包的callable函数, - 当其为callable类型时候,所有数据集采样的数据都会经过这个函数; - 当其为List[Callable]类型时,datasets也应该为List;会根据每个数据集__getitem__返回的idx判断当前数据对应的Callable函数, - 其对应关系与datasets位置匹配; - 当其为Dict[str, Callable]类型时, datasets也是Dict类型且一一对应。 - :param sampler: sampler是datasets每个数据集内部采样的实例化sampler对象 - sampler为None时候,datasets包含的每个dataset都会初始化一个sequentialSampler用于采样; - sampler为List[Sampler],则datasets也为List,且一一对应 - sampler为Dict[str, Sampler], datasets也是Dict类型且一一对应。 - :param num_workers: 进程的数量,当num_workers=0时不开启多进程 - :param batch_size: 批次大小, datasets的所有数据集batch_size一致 - :param drop_last: 是否去掉最后一个不符合batch_size的数据 - :param ds_ratio: 当ds_ratio为None,原有数据集不进行扩充 - 当ds_ratio为'truncate_to_least'时,以datasets的最短数据集为基准,将其他数据集截断到一样长度 - 当ds_ratio为'pad_to_most'时,以datasets的最长数据集为基准,将最短数据集重采样到最长数据集长度一致为止 - 当ds_ratio为List[float]时,datasets也为List,ds_ratio的每一个参数都是datasets每个数据集应该采样的倍数, - 其大于0,可以超过1,将数据集重采样翻倍即可 - 当ds_ratio为Dict[str, float]时,datasets也为Dict,参数相互对应。 + :param datasets: 实现了 __getitem__() 和 __len__() 对象的序列或者字典。 + :param mode: mode 控制 ``MixDataLoader`` 运行模式。 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: + + * 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 + 接一个的 sample 完所有数据。 + * 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample + 混合数据集 datasets 的数据组成一个 batch 序列返回。 + * 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, + 再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 + * 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; + 且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0 0`` 时, ``MixDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快数据处理速度,但同时 + 也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 + :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 且 datasets 上所有 dataset 的 batch_size 一致。 + :param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据; + 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 + :param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``: + + * ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。 + * ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最断长度 ``mix_len``, 其他数据集会被切断 + 到最短长度``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度``mix_len``。 + * ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他其他数据集会扩充 + 到最大长度``mix_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度``max_len``。 + * ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数, + 代表着 datasets 的 value 数据进行扩充或者缩减的倍数。 """ - # 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, - if not isinstance(datasets, Dict) and (isinstance(collate_fn, Callable) or isinstance(collate_fn, Dict)) and \ - isinstance(sampler, Dict): - raise ValueError(f"") + # sampler 为 dict,则判断是否与 datasets 的 key 相同 + if isinstance(sampler, Dict): + for key in datasets.keys(): + if not sampler[key]: + raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!") + # collate_fn 为 dict,则判断是否与 datasets 的 key 相同 + if isinstance(collate_fn, Dict): + if mode == 'mix': + raise ValueError(f"mode: {mode} do not support collate_fn is Dict, please use callate_fn=Callable or 'auto'") + for key in datasets.keys(): + if not collate_fn[key]: + raise ValueError(f"the key:{key} of datasets is not in collate_fn, where collate_fn is a dict!") - if isinstance(collate_fn, list): - if len(collate_fn) != len(datasets): - raise ValueError("the length of collate_fn != datasets!!") + if isinstance(collate_fn, str) and collate_fn == 'auto': + date_type = None + for idx, ds in enumerate(datasets.values()): + if idx == 0: + date_type = type(ds[0]) + if type(ds[0]) != date_type or not (isinstance(ds[0], List) or isinstance(ds[0], Mapping)): + raise ValueError(f"when you use callate_fn={collate_fn}, all dataset must be list or dict。" + f"But dataset {idx - 1} data type is {date_type}, dataset {idx} data type is {type(ds[0])}") - if isinstance(sampler, list): - if len(sampler) != len(datasets): - raise ValueError("the length of sampler != datasets!!") + collate_fn = Collator(backend='torch') - # Dict类型转化为List,以便于_MixCollateFn处理 + # Dict 类型的 collate_fn 转化为 List,以便于 _MixCollateFn 里面根据 idx 定位 dataset if isinstance(collate_fn, Dict): collate_fn = [fn for _, fn in collate_fn.items()] - # 由于datasets可能是FastNLP类型的dataset或者是交杂的, 故需要检测 - if isinstance(datasets, Dict): - dataset = [ds for _, ds in datasets.items()] - else: - dataset = datasets - auto_collators = [] - for per_ds in dataset: - if isinstance(per_ds, DataSet): - auto_collators.append(per_ds.get_collator()) - else: - # 如果没有对应的collator就设置一个不做任何操作的collator - auto_collators.append(lambda x: x) + dataset = [ds for _, ds in datasets.items()] + + # 对 collate_fn 进行包裹, 统一处理 collate_fn 不同情况下使用的问题 + collate_fn = _MixCollateFn(collate_fn) - # List类型的collate_fn只有两种情况,需要对其进行包裹 - collate_fn = _MixCollateFn(collate_fn, auto_collators) if mode == 'sequential': batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler, drop_last=drop_last, ds_ratio=ds_ratio) diff --git a/fastNLP/core/samplers/mix_sampler.py b/fastNLP/core/samplers/mix_sampler.py index 0aa543be..774b5c7c 100644 --- a/fastNLP/core/samplers/mix_sampler.py +++ b/fastNLP/core/samplers/mix_sampler.py @@ -21,9 +21,9 @@ class MixSampler: mix_sampler的基类 """ - def __init__(self, dataset: Union[List, Dict], batch_size: int = None, - sampler: Union[List["Sampler"], Dict[str, "Sampler"], None, str] = None, - ds_ratio: Union[str, List[float], Dict[str, float]] = None, + def __init__(self, dataset: Dict, batch_size: int = None, + sampler: Union[Dict[str, "Sampler"], None, str] = None, + ds_ratio: Union[str, Dict[str, float]] = None, drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: """ @@ -32,9 +32,12 @@ class MixSampler: :param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 :param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size """ - # 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, - if isinstance(dataset, Dict) and isinstance(sampler, List): - raise ValueError(f"{sampler} must be dict") + # sampler 为 dict,则判断是否与 datasets 的 key 相同 + if isinstance(sampler, Dict): + for key in dataset.keys(): + if not sampler[key]: + raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!") + if batch_size <= 0: raise ValueError("batch_size should be a positive integer value, " "but got batch_size={}".format(batch_size)) @@ -46,15 +49,7 @@ class MixSampler: raise ValueError("if rank>=0 and word_size>=0, sampler must be str") if sampler is None and (word_size < 0 or rank < 0): - if isinstance(dataset, List): - self.sampler = [SequentialSampler(ds) for ds in dataset] - elif isinstance(dataset, Dict): - self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} - - elif isinstance(sampler, List): - if len(sampler) != len(dataset): - raise ValueError("the length of sampler != the length of sampler") - self.sampler = sampler + self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} elif isinstance(sampler, Dict): self.sampler = sampler @@ -68,26 +63,7 @@ class MixSampler: # 计算扩展后的大数据集长度total_len和扩展后的单个数据集长度sampler_len sampler_lens, total_lens, sampler_index = [], 0, [] - if isinstance(self.sampler, List): - if ds_ratio is None: - sampler_lens = [len(spl) for spl in self.sampler] - - elif ds_ratio == 'pad_to_most': - sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) - - elif ds_ratio == 'truncate_to_least': - sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) - - elif isinstance(ds_ratio, List): - if not all(item >= 0 for item in ds_ratio): - raise ValueError("batch_size should be a positive integer value, " - "but got ds_ratio={}".format(ds_ratio)) - sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, ds_ratio)] - else: - raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List") - total_lens = sum(sampler_lens) - - elif isinstance(self.sampler, Dict): + if isinstance(self.sampler, Dict): if ds_ratio is None: sampler_lens = [len(spl) for _, spl in self.sampler.items()] @@ -100,7 +76,7 @@ class MixSampler: sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len elif isinstance(ds_ratio, Dict): - if not all(item >= 0 for item in ds_ratio): + if not all([item >= 0 for item in ds_ratio.values()]): raise ValueError("batch_size should be a positive integer value, " "but got ds_ratio={}".format(ds_ratio)) sampler_lens = [int(len(spl) * ds_ratio[name]) for name, spl in self.sampler.items()] @@ -108,7 +84,7 @@ class MixSampler: raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List") total_lens = sum(sampler_lens) - # sampler为str时候,初始化下移到iter方法中 + # sampler 为 str 时候,初始化下移到 iter 方法中 if len(sampler_lens) > 0: sampler_index = [sampler_lens[0]] for idx in sampler_lens[1:]: @@ -160,75 +136,37 @@ class DopedSampler(MixSampler): """ 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表混合采样组成一个个batch返回。 """ - def __init__(self, dataset: Union[List, Dict], batch_size: int = None, - sampler: Union[List["Sampler"], Dict[str, "Sampler"], str] = None, - ds_ratio: Union[str, None, List[float], Dict[str, float]] = None, + def __init__(self, dataset: Dict, batch_size: int = None, + sampler: Union[Dict[str, "Sampler"], str] = None, + ds_ratio: Union[str, None, Dict[str, float]] = None, drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: super(DopedSampler, self).__init__(dataset=dataset, batch_size=batch_size, sampler=sampler, ds_ratio=ds_ratio, drop_last=drop_last, rank=rank, word_size=word_size) def __iter__(self) -> List[int]: - # sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 + # sampler 为 str, 此时为单机多卡或者单机,可以实现 rand 随机化 if isinstance(self.sampler, str): if self.sampler == 'seq': - if isinstance(self.datasets, List): - self.sampler = [] - for per_ds in self.datasets: - if self.word_size >= 0 and self.rank >= 0: - self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) - else: - self.sampler.append(InnerSampler(list(range(len(per_ds))))) - elif isinstance(self.datasets, Dict): - self.sampler = {} - for name, per_ds in self.datasets.items(): - if self.word_size >= 0 and self.rank >= 0: - self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) - else: - self.sampler[name] = InnerSampler(list(range(len(per_ds)))) + self.sampler = {} + for name, per_ds in self.datasets.items(): + if self.word_size >= 0 and self.rank >= 0: + self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) + else: + self.sampler[name] = InnerSampler(list(range(len(per_ds)))) elif self.sampler == 'rand': - if isinstance(self.datasets, List): - self.sampler = [] - for per_ds in self.datasets: - g = torch.Generator() - g.manual_seed(self.epoch) - indices = torch.randperm(len(per_ds), generator=g).tolist() - if self.word_size >= 0 and self.rank >= 0: - self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) - else: - self.sampler.append(InnerSampler(indices)) - elif isinstance(self.datasets, Dict): - self.sampler = {} - for name, per_ds in self.datasets.items(): - g = torch.Generator() - g.manual_seed(self.epoch) - indices = torch.randperm(len(per_ds), generator=g).tolist() - if self.word_size >= 0 and self.rank >= 0: - self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) - else: - self.sampler[name] = InnerSampler(indices) + self.sampler = {} + for name, per_ds in self.datasets.items(): + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(per_ds), generator=g).tolist() + if self.word_size >= 0 and self.rank >= 0: + self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) + else: + self.sampler[name] = InnerSampler(indices) # 根据给定的ds_ratio计算真正需要处理数据集 - if isinstance(self.sampler, List): - if self.ds_ratio is None: - sampler_lens = [len(spl) for spl in self.sampler] - - elif self.ds_ratio == 'pad_to_most': - sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) - - elif self.ds_ratio == 'truncate_to_least': - sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) - - elif isinstance(self.ds_ratio, List): - if not all(item >= 0 for item in self.ds_ratio): - raise ValueError("batch_size should be a positive integer value, " - "but got ds_ratio={}".format(self.ds_ratio)) - sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] - else: - raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") - total_lens = sum(sampler_lens) - - elif isinstance(self.sampler, Dict): + if isinstance(self.sampler, Dict): if self.ds_ratio is None: sampler_lens = [len(spl) for _, spl in self.sampler.items()] @@ -257,11 +195,11 @@ class DopedSampler(MixSampler): sampler_index.append(temp + idx) self.num_samplers = sampler_index self.len_samplers = total_lens - # 每个batch的数据, 总的数据量total_index, 每个数据集的samplers + # 每个 batch 的数据, 总的数据量 total_index , 每个数据集的 samplers batch_idx, samplers = [], [] # 如果单机则用所有数据,否则采用多卡 if self.rank < 0 or self.word_size < 0: - # 根据sampler长度判断是否使用unsigned int 或者unsigned long + # 根据 sampler 长度判断是否使用 unsigned int 或者 unsigned long if self.len_samplers > 42e8: total_index = array.array('L', list(range(self.len_samplers))) else: @@ -274,15 +212,17 @@ class DopedSampler(MixSampler): else: total_index = array.array('I', list(range(self.len_samplers))[self.rank::self.word_size]) + start_idx = 0 + + # (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) + for idx, (name, spl) in enumerate(self.sampler.items()): + end_idx = len(spl) + samplers.append((iter(spl), name, start_idx)) + start_idx += end_idx # 根据sampler的类型取出每个数据集的sampler - if isinstance(self.sampler, List): - sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1] - samplers = [(iter(spl), idx, base_index) - for idx, (spl, base_index) in enumerate(zip(self.sampler, sampler_base_index))] - else: - sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] - samplers = [(iter(spl), name, sampler_base_index[idx]) - for idx, (name, spl) in enumerate(self.sampler.items())] + # sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] + # samplers = [(iter(spl), name, sampler_base_index[idx]) + # for idx, (name, spl) in enumerate(self.sampler.items())] # 生成随机数 np.random.seed(self.epoch) np.random.shuffle(total_index) @@ -295,7 +235,7 @@ class DopedSampler(MixSampler): # 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration spl = iter(self.sampler[name]) batch_idx.append(next(spl) + base_index) - samplers[name] = (spl, name, base_index) + samplers[ds_index] = (spl, name, base_index) if len(batch_idx) == self.batch_size: yield batch_idx batch_idx = [] @@ -343,63 +283,26 @@ class MixSequentialSampler(MixSampler): # sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 if isinstance(self.sampler, str): if self.sampler == 'seq': - if isinstance(self.datasets, List): - self.sampler = [] - for per_ds in self.datasets: - if self.word_size >= 0 and self.rank >= 0: - self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) - else: - self.sampler.append(InnerSampler(list(range(len(per_ds))))) - elif isinstance(self.datasets, Dict): - self.sampler = {} - for name, per_ds in self.datasets.items(): - if self.word_size >= 0 and self.rank >= 0: - self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) - else: - self.sampler[name] = InnerSampler(list(range(len(per_ds)))) + self.sampler = {} + for name, per_ds in self.datasets.items(): + if self.word_size >= 0 and self.rank >= 0: + self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) + else: + self.sampler[name] = InnerSampler(list(range(len(per_ds)))) elif self.sampler == 'rand': - if isinstance(self.datasets, List): - self.sampler = [] - for per_ds in self.datasets: - g = torch.Generator() - g.manual_seed(self.epoch) - indices = torch.randperm(len(per_ds), generator=g).tolist() - if self.word_size >= 0 and self.rank >= 0: - self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) - else: - self.sampler.append(InnerSampler(indices)) - elif isinstance(self.datasets, Dict): - self.sampler = {} - for name, per_ds in self.datasets.items(): - g = torch.Generator() - g.manual_seed(self.epoch) - indices = torch.randperm(len(per_ds), generator=g).tolist() - if self.word_size >= 0 and self.rank >= 0: - self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) - else: - self.sampler[name] = InnerSampler(indices) - # 根据给定的ds_ratio计算真正需要处理数据集 - if isinstance(self.sampler, List): - if self.ds_ratio is None: - sampler_lens = [len(spl) for spl in self.sampler] + self.sampler = {} + for name, per_ds in self.datasets.items(): + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(per_ds), generator=g).tolist() + if self.word_size >= 0 and self.rank >= 0: + self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) + else: + self.sampler[name] = InnerSampler(indices) - elif self.ds_ratio == 'pad_to_most': - sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) - - elif self.ds_ratio == 'truncate_to_least': - sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) - - elif isinstance(self.ds_ratio, List): - if not all(item >= 0 for item in self.ds_ratio): - raise ValueError("batch_size should be a positive integer value, " - "but got ds_ratio={}".format(self.ds_ratio)) - sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] - else: - raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") - total_lens = sum(sampler_lens) - - elif isinstance(self.sampler, Dict): + # 根据给定的 ds_ratio 算真正需要处理数据集 + if isinstance(self.sampler, Dict): if self.ds_ratio is None: sampler_lens = [len(spl) for _, spl in self.sampler.items()] @@ -430,21 +333,20 @@ class MixSequentialSampler(MixSampler): self.len_samplers = total_lens batch_idx, total_index, samplers = [], list(range(self.len_samplers)), [] - if isinstance(self.sampler, List): - if self.word_size > 0 and self.rank >= 0: - sampler_base_index = [0] + [len(spl) * self.word_size for spl in self.sampler][:-1] - else: - sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1] - samplers = [(iter(spl), idx, base_index) for idx, (spl, base_index) in - enumerate(zip(self.sampler, sampler_base_index))] - else: - if self.word_size > 0 and self.rank >= 0: - sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1] - else: - sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] + start_idx = 0 - samplers = [(iter(spl), name, sampler_base_index[idx]) - for idx, (name, spl) in enumerate(self.sampler.items())] + # (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) + for idx, (name, spl) in enumerate(self.sampler.items()): + end_idx = len(spl) + samplers.append((iter(spl), name, start_idx)) + start_idx += end_idx + # if self.word_size > 0 and self.rank >= 0: + # sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1] + # else: + # sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] + # + # samplers = [(iter(spl), name, sampler_base_index[idx]) + # for idx, (name, spl) in enumerate(self.sampler.items())] for idx in total_index: ds_index = np.searchsorted(self.num_samplers, idx, side='right') @@ -455,7 +357,7 @@ class MixSequentialSampler(MixSampler): # 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration spl = iter(self.sampler[name]) batch_idx.append(next(spl) + base_index) - samplers[name] = (spl, name, base_index) + samplers[ds_index] = (spl, name, base_index) if len(batch_idx) == self.batch_size: yield batch_idx batch_idx = [] @@ -506,63 +408,26 @@ class PollingSampler(MixSampler): # sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 if isinstance(self.sampler, str): if self.sampler == 'seq': - if isinstance(self.datasets, List): - self.sampler = [] - for per_ds in self.datasets: - if self.word_size >= 0 and self.rank >= 0: - self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) - else: - self.sampler.append(InnerSampler(list(range(len(per_ds))))) - elif isinstance(self.datasets, Dict): - self.sampler = {} - for name, per_ds in self.datasets.items(): - if self.word_size >= 0 and self.rank >= 0: - self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) - else: - self.sampler[name] = InnerSampler(list(range(len(per_ds)))) + self.sampler = {} + for name, per_ds in self.datasets.items(): + if self.word_size >= 0 and self.rank >= 0: + self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) + else: + self.sampler[name] = InnerSampler(list(range(len(per_ds)))) elif self.sampler == 'rand': - if isinstance(self.datasets, List): - self.sampler = [] - for per_ds in self.datasets: - g = torch.Generator() - g.manual_seed(self.epoch) - indices = torch.randperm(len(per_ds), generator=g).tolist() - if self.word_size >= 0 and self.rank >= 0: - self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) - else: - self.sampler.append(InnerSampler(indices)) - elif isinstance(self.datasets, Dict): - self.sampler = {} - for name, per_ds in self.datasets.items(): - g = torch.Generator() - g.manual_seed(self.epoch) - indices = torch.randperm(len(per_ds), generator=g).tolist() - if self.word_size >= 0 and self.rank >= 0: - self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) - else: - self.sampler[name] = InnerSampler(indices) + + self.sampler = {} + for name, per_ds in self.datasets.items(): + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(per_ds), generator=g).tolist() + if self.word_size >= 0 and self.rank >= 0: + self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) + else: + self.sampler[name] = InnerSampler(indices) # 根据给定的ds_ratio计算真正需要处理数据集 - if isinstance(self.sampler, List): - if self.ds_ratio is None: - sampler_lens = [len(spl) for spl in self.sampler] - - elif self.ds_ratio == 'pad_to_most': - sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) - - elif self.ds_ratio == 'truncate_to_least': - sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) - - elif isinstance(self.ds_ratio, List): - if not all(item >= 0 for item in self.ds_ratio): - raise ValueError("batch_size should be a positive integer value, " - "but got ds_ratio={}".format(self.ds_ratio)) - sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] - else: - raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") - total_lens = sum(sampler_lens) - - elif isinstance(self.sampler, Dict): + if isinstance(self.sampler, Dict): if self.ds_ratio is None: sampler_lens = [len(spl) for _, spl in self.sampler.items()] @@ -592,17 +457,15 @@ class PollingSampler(MixSampler): self.num_samplers = sampler_index self.len_samplers = total_lens - start_idx, samplers = 0, [] - if isinstance(self.sampler, List): - # (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) - for sampler_idx, (end_idx, spl) in enumerate(zip(self.num_samplers, self.sampler)): - samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, sampler_idx)) - start_idx = end_idx - else: - for idx, (name, spl) in enumerate(self.sampler.items()): - end_idx = self.num_samplers[idx] - samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, name)) - start_idx = end_idx + start_idx, samplers, true_start_idx, true_end_idx = 0, [], 0, 0 + + # (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) + for idx, (name, spl) in enumerate(self.sampler.items()): + end_idx = len(spl) + true_end_idx = self.num_samplers[idx] + samplers.append((iter(range(true_start_idx, true_end_idx)), iter(spl), start_idx, name)) + start_idx += end_idx + true_start_idx = true_end_idx while True: # 退出循环 diff --git a/tests/core/dataloaders/test_mixdataloader.py b/tests/core/dataloaders/test_mixdataloader.py new file mode 100644 index 00000000..35872b39 --- /dev/null +++ b/tests/core/dataloaders/test_mixdataloader.py @@ -0,0 +1,495 @@ +import pytest +from typing import Mapping + +from fastNLP.core.dataloaders import MixDataLoader +from fastNLP import DataSet +from fastNLP.core.collators import Collator +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + from torch.utils.data import default_collate, SequentialSampler, RandomSampler + +d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + +d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10] * 10}) + +d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100}) + + +def test_pad_val(tensor, val=0): + if isinstance(tensor, torch.Tensor): + tensor = tensor.tolist() + for item in tensor: + if item[-1] > 0: + continue + elif item[-1] != val: + return False + return True + + +class TestMixDataLoader: + + def test_sequential_init(self): + datasets = {'d1': d1, 'd2': d2, 'd3': d3} + # drop_last = True, collate_fn = 'auto + dl = MixDataLoader(datasets=datasets, mode='sequential', collate_fn='auto', drop_last=True) + for idx, batch in enumerate(dl): + if idx == 0: + # d1 + assert batch['x'].shape == torch.Size([16, 4]) + if idx == 1: + # d2 + assert batch['x'].shape == torch.Size([16, 3]) + if idx > 1: + # d3 + assert batch['x'].shape == torch.Size([16, 4]) + assert test_pad_val(batch['x'], val=0) + + # collate_fn = Callable + def collate_batch(batch): + new_batch = {'x': [], 'y': []} + for ins in batch: + new_batch['x'].append(ins['x']) + new_batch['y'].append(ins['y']) + return new_batch + + dl1 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_batch, drop_last=True) + for idx, batch in enumerate(dl1): + if idx == 0: + # d1 + assert [1, 2] in batch['x'] + if idx == 1: + # d2 + assert [101, 201] in batch['x'] + if idx > 1: + # d3 + assert [1000, 2000] in batch['x'] + assert 'x' in batch and 'y' in batch + + collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), + 'd2': Collator(backend='auto').set_pad("x", -2), + 'd3': Collator(backend='auto').set_pad("x", -3)} + dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True) + for idx, batch in enumerate(dl2): + if idx == 0: + assert test_pad_val(batch['x'], val=-1) + assert batch['x'].shape == torch.Size([16, 4]) + if idx == 1: + assert test_pad_val(batch['x'], val=-2) + assert batch['x'].shape == torch.Size([16, 3]) + if idx > 1: + assert test_pad_val(batch['x'], val=-3) + assert batch['x'].shape == torch.Size([16, 4]) + + # sampler 为 str + dl3 = MixDataLoader(datasets=datasets, mode='sequential', sampler='seq', drop_last=True) + dl4 = MixDataLoader(datasets=datasets, mode='sequential', sampler='rand', drop_last=True) + for idx, batch in enumerate(dl3): + if idx == 0: + # d1 + assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] + assert batch['x'].shape == torch.Size([16, 4]) + if idx == 1: + # d2 + assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] + assert batch['x'].shape == torch.Size([16, 3]) + if idx == 2: + # d3 + assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] + if idx > 1: + # d3 + assert batch['x'].shape == torch.Size([16, 4]) + assert test_pad_val(batch['x'], val=0) + + for idx, batch in enumerate(dl4): + if idx == 0: + # d1 + assert batch['x'][:3].tolist() != [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] + assert batch['x'].shape == torch.Size([16, 4]) + if idx == 1: + # d2 + assert batch['x'][:3].tolist() != [[101, 201, 0], [201, 301, 401], [100, 0, 0]] + assert batch['x'].shape == torch.Size([16, 3]) + if idx == 2: + # d3 + assert batch['x'][:3].tolist() != [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] + if idx > 1: + # d3 + assert batch['x'].shape == torch.Size([16, 4]) + assert test_pad_val(batch['x'], val=0) + + # sampler 为 Dict + samplers = {'d1': SequentialSampler(d1), + 'd2': SequentialSampler(d2), + 'd3': RandomSampler(d3)} + dl5 = MixDataLoader(datasets=datasets, mode='sequential', sampler=samplers, drop_last=True) + for idx, batch in enumerate(dl5): + if idx == 0: + # d1 + assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] + assert batch['x'].shape == torch.Size([16, 4]) + if idx == 1: + # d2 + assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] + assert batch['x'].shape == torch.Size([16, 3]) + if idx > 1: + # d3 + assert batch['x'].shape == torch.Size([16, 4]) + assert test_pad_val(batch['x'], val=0) + + # ds_ratio 为 'truncate_to_least' + dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True) + for idx, batch in enumerate(dl6): + if idx == 0: + # d1 + assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] + assert batch['x'].shape == torch.Size([16, 4]) + if idx == 1: + # d2 + assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] + assert batch['x'].shape == torch.Size([16, 3]) + if idx == 2: + # d3 + assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] + assert batch['x'].shape == torch.Size([16, 4]) + assert test_pad_val(batch['x'], val=0) + if idx > 2: + raise ValueError(f"ds_ratio: 'truncate_to_least' error") + + # ds_ratio 为 'pad_to_most' + dl7 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='pad_to_most', drop_last=True) + for idx, batch in enumerate(dl7): + if idx < 18: + # d1 + assert batch['x'].shape == torch.Size([16, 4]) + if 18 <= idx < 36: + # d2 + assert batch['x'].shape == torch.Size([16, 3]) + if 36 <= idx < 54: + # d3 + assert batch['x'].shape == torch.Size([16, 4]) + assert test_pad_val(batch['x'], val=0) + if idx >= 54: + raise ValueError(f"ds_ratio: 'pad_to_most' error") + + # ds_ratio 为 Dict[str, float] + ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} + dl8 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True) + for idx, batch in enumerate(dl8): + if idx < 1: + # d1 + assert batch['x'].shape == torch.Size([16, 4]) + if 1 <= idx < 4: + # d2 + assert batch['x'].shape == torch.Size([16, 3]) + if 4 <= idx < 41: + # d3 + assert batch['x'].shape == torch.Size([16, 4]) + assert test_pad_val(batch['x'], val=0) + if idx >= 41: + raise ValueError(f"ds_ratio: 'pad_to_most' error") + + ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} + dl9 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True) + for idx, batch in enumerate(dl9): + if idx < 1: + # d2 + assert batch['x'].shape == torch.Size([16, 3]) + if 1 <= idx < 19: + # d3 + assert batch['x'].shape == torch.Size([16, 4]) + + assert test_pad_val(batch['x'], val=0) + if idx >= 19: + raise ValueError(f"ds_ratio: 'pad_to_most' error") + + def test_mix(self): + datasets = {'d1': d1, 'd2': d2, 'd3': d3} + dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True) + for idx, batch in enumerate(dl): + assert test_pad_val(batch['x'], val=0) + if idx >= 22: + raise ValueError(f"out of range") + + # collate_fn = Callable + def collate_batch(batch): + new_batch = {'x': [], 'y': []} + for ins in batch: + new_batch['x'].append(ins['x']) + new_batch['y'].append(ins['y']) + return new_batch + + dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True) + for idx, batch in enumerate(dl1): + assert isinstance(batch['x'], list) + assert test_pad_val(batch['x'], val=0) + if idx >= 22: + raise ValueError(f"out of range") + + collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), + 'd2': Collator(backend='auto').set_pad("x", -2), + 'd3': Collator(backend='auto').set_pad("x", -3)} + with pytest.raises(ValueError): + MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_fns) + + # sampler 为 str + dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True) + for idx, batch in enumerate(dl3): + assert test_pad_val(batch['x'], val=0) + if idx >= 22: + raise ValueError(f"out of range") + dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True) + for idx, batch in enumerate(dl4): + assert test_pad_val(batch['x'], val=0) + if idx >= 22: + raise ValueError(f"out of range") + # sampler 为 Dict + samplers = {'d1': SequentialSampler(d1), + 'd2': SequentialSampler(d2), + 'd3': RandomSampler(d3)} + dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True) + for idx, batch in enumerate(dl5): + assert test_pad_val(batch['x'], val=0) + if idx >= 22: + raise ValueError(f"out of range") + # ds_ratio 为 'truncate_to_least' + dl6 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='truncate_to_least') + d1_len, d2_len, d3_len = 0, 0, 0 + for idx, batch in enumerate(dl6): + for item in batch['y'].tolist(): + if item in [1, 0, 1]: + d1_len += 1 + elif item in [20, 10, 10]: + d2_len += 1 + elif item in [100, 100, 200]: + d3_len += 1 + if idx >= 6: + raise ValueError(f"ds_ratio 为 'truncate_to_least'出错了") + assert d1_len == d2_len == d3_len == 30 + + # ds_ratio 为 'pad_to_most' + dl7 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='pad_to_most') + d1_len, d2_len, d3_len = 0, 0, 0 + for idx, batch in enumerate(dl7): + for item in batch['y'].tolist(): + if item in [1, 0, 1]: + d1_len += 1 + elif item in [20, 10, 10]: + d2_len += 1 + elif item in [100, 100, 200]: + d3_len += 1 + + if idx >= 57: + raise ValueError(f"ds_ratio 为 'pad_to_most'出错了") + assert d1_len == d2_len == d3_len == 300 + + # ds_ratio 为 Dict[str, float] + ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} + dl8 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio) + d1_len, d2_len, d3_len = 0, 0, 0 + for idx, batch in enumerate(dl8): + for item in batch['y'].tolist(): + if item in [1, 0, 1]: + d1_len += 1 + elif item in [20, 10, 10]: + d2_len += 1 + elif item in [100, 100, 200]: + d3_len += 1 + if idx >= 44: + raise ValueError(f"ds_ratio 为 'Dict'出错了") + assert d1_len == 30 + assert d2_len == 60 + assert d3_len == 600 + + ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} + dl9 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio) + d1_len, d2_len, d3_len = 0, 0, 0 + for idx, batch in enumerate(dl9): + for item in batch['y'].tolist(): + if item in [1, 0, 1]: + d1_len += 1 + elif item in [20, 10, 10]: + d2_len += 1 + elif item in [100, 100, 200]: + d3_len += 1 + if idx >= 21: + raise ValueError(f"ds_ratio 为 'Dict'出错了") + + def test_polling(self): + datasets = {'d1': d1, 'd2': d2, 'd3': d3} + dl = MixDataLoader(datasets=datasets, mode='polling', collate_fn='auto', batch_size=18) + for idx, batch in enumerate(dl): + if idx == 0 or idx == 3: + assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] + assert batch['x'].shape[1] == 4 + elif idx == 1 or idx == 4: + # d2 + assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] + assert batch['x'].shape[1] == 3 + elif idx == 2 or 4 < idx <= 20: + assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] + assert batch['x'].shape[1] == 4 + if idx > 20: + raise ValueError(f"out of range") + test_pad_val(batch['x'], val=0) + + # collate_fn = Callable + def collate_batch(batch): + new_batch = {'x': [], 'y': []} + for ins in batch: + new_batch['x'].append(ins['x']) + new_batch['y'].append(ins['y']) + return new_batch + + dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_batch, batch_size=18) + for idx, batch in enumerate(dl1): + if idx == 0 or idx == 3: + assert batch['x'][:3] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]] + elif idx == 1 or idx == 4: + # d2 + assert batch['x'][:3] == [[101, 201], [201, 301, 401], [100]] + elif idx == 2 or 4 < idx <= 20: + assert batch['x'][:3] == [[1000, 2000], [0], [2000, 3000, 4000, 5000]] + if idx > 20: + raise ValueError(f"out of range") + + collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), + 'd2': Collator(backend='auto').set_pad("x", -2), + 'd3': Collator(backend='auto').set_pad("x", -3)} + dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18) + for idx, batch in enumerate(dl1): + if idx == 0 or idx == 3: + assert test_pad_val(batch['x'], val=-1) + assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]] + assert batch['x'].shape[1] == 4 + elif idx == 1 or idx == 4: + # d2 + assert test_pad_val(batch['x'], val=-2) + assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]] + assert batch['x'].shape[1] == 3 + elif idx == 2 or 4 < idx <= 20: + assert test_pad_val(batch['x'], val=-3) + assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]] + assert batch['x'].shape[1] == 4 + if idx > 20: + raise ValueError(f"out of range") + + # sampler 为 str + dl2 = MixDataLoader(datasets=datasets, mode='polling', sampler='seq', batch_size=18) + dl3 = MixDataLoader(datasets=datasets, mode='polling', sampler='rand', batch_size=18) + for idx, batch in enumerate(dl2): + if idx == 0 or idx == 3: + assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] + assert batch['x'].shape[1] == 4 + elif idx == 1 or idx == 4: + # d2 + assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] + assert batch['x'].shape[1] == 3 + elif idx == 2 or 4 < idx <= 20: + assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] + assert batch['x'].shape[1] == 4 + if idx > 20: + raise ValueError(f"out of range") + test_pad_val(batch['x'], val=0) + for idx, batch in enumerate(dl3): + if idx == 0 or idx == 3: + assert batch['x'].shape[1] == 4 + elif idx == 1 or idx == 4: + # d2 + assert batch['x'].shape[1] == 3 + elif idx == 2 or 4 < idx <= 20: + assert batch['x'].shape[1] == 4 + if idx > 20: + raise ValueError(f"out of range") + test_pad_val(batch['x'], val=0) + # sampler 为 Dict + samplers = {'d1': SequentialSampler(d1), + 'd2': SequentialSampler(d2), + 'd3': RandomSampler(d3)} + dl4 = MixDataLoader(datasets=datasets, mode='polling', sampler=samplers, batch_size=18) + for idx, batch in enumerate(dl4): + if idx == 0 or idx == 3: + assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] + assert batch['x'].shape[1] == 4 + elif idx == 1 or idx == 4: + # d2 + assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] + assert batch['x'].shape[1] == 3 + elif idx == 2 or 4 < idx <= 20: + assert batch['x'].shape[1] == 4 + if idx > 20: + raise ValueError(f"out of range") + test_pad_val(batch['x'], val=0) + + # ds_ratio 为 'truncate_to_least' + dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18) + for idx, batch in enumerate(dl5): + if idx == 0 or idx == 3: + assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] + assert batch['x'].shape[1] == 4 + elif idx == 1 or idx == 4: + # d2 + assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] + assert batch['x'].shape[1] == 3 + elif idx == 2 or idx == 5: + assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] + assert batch['x'].shape[1] == 4 + if idx > 5: + raise ValueError(f"out of range") + test_pad_val(batch['x'], val=0) + + # ds_ratio 为 'pad_to_most' + dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18) + for idx, batch in enumerate(dl6): + if idx % 3 == 0: + # d1 + assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] + assert batch['x'].shape[1] == 4 + if idx % 3 == 1: + # d2 + assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] + assert batch['x'].shape[1] == 3 + if idx % 3 == 2: + # d3 + assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] + assert batch['x'].shape[1] == 4 + if idx >= 51: + raise ValueError(f"out of range") + test_pad_val(batch['x'], val=0) + + # ds_ratio 为 Dict[str, float] + ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} + dl7 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) + for idx, batch in enumerate(dl7): + if idx == 0 or idx == 3: + assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] + assert batch['x'].shape[1] == 4 + elif idx == 1 or idx == 4 or idx == 6 or idx == 8: + # d2 + assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] + assert batch['x'].shape[1] == 3 + elif idx == 2 or idx == 5 or idx == 7 or idx > 8: + assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] + assert batch['x'].shape[1] == 4 + if idx > 39: + raise ValueError(f"out of range") + test_pad_val(batch['x'], val=0) + + ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} + dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) + for idx, batch in enumerate(dl8): + if idx == 0: + assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] + assert batch['x'].shape[1] == 4 + elif idx == 1: + # d2 + assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] + assert batch['x'].shape[1] == 3 + elif idx > 1: + assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] + assert batch['x'].shape[1] == 4 + + if idx > 18: + raise ValueError(f"out of range") + test_pad_val(batch['x'], val=0) \ No newline at end of file