Add MixDataLoader documentation, revise the mix_sampler and MixDataLoader code, and add corresponding test cases

This commit is contained in:
MorningForest 2022-05-20 18:04:00 +08:00
parent 916c113322
commit 9bb1ed4ccf
4 changed files with 700 additions and 311 deletions


@ -24,7 +24,7 @@ from fastNLP.core.dataset import DataSet as FDataSet
class _JittorDataset(Dataset):
"""
Wraps the user-provided dataset so that JittorDataLoader can support custom datasets
"""
def __init__(self, dataset) -> None:
@ -37,7 +37,7 @@ class _JittorDataset(Dataset):
item = item.tolist()
return (item, self.dataset[item])
class JittorDataLoader:
"""
The ``DataLoader`` provided for the ``jittor`` framework. ``JittorDataLoader`` provides a ``Collator`` that automatically detects whether each field of the dataset can be padded

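The wrapped dataset returns the sampled index together with the data, which is how later stages can tell where each instance came from. A minimal standalone sketch of the same idea (the class below is illustrative, not the fastNLP implementation):

class IndexedDataset:
    """Wrap any object with __getitem__/__len__ and yield (idx, data) pairs,
    mirroring _JittorDataset.__getitem__ -> (item, self.dataset[item])."""
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        return (item, self.dataset[item])

pairs = IndexedDataset(['a', 'b', 'c'])
assert pairs[1] == (1, 'b')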

@ -2,13 +2,14 @@ __all__ = [
'MixDataLoader'
]
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping
import numpy as np
from fastNLP.core.dataset import DataSet, Instance
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.collators import Collator
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader, Sampler
@ -18,12 +19,13 @@ else:
class _MixDataset:
"""
Treats all the datasets as one large mixed dataset; the implemented __getitem__ can tell which dataset each idx belongs to
Treats all the datasets as one large mixed dataset; __getitem__() determines from the input idx which sub-dataset the item belongs to and returns its ds_index
"""
def __init__(self, datasets: list = None) -> None:
"""
:param datasets: a list of datasets
:param datasets: a sequence of objects implementing __getitem__() and __len__()
"""
self.datasets = datasets
# Record the cumulative length index of each dataset so that a dataset can be located from idx
@ -35,7 +37,7 @@ class _MixDataset:
def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]:
"""
Fetch data by index
Fetch data by index; the range that idx falls in determines which sub-dataset it belongss to, and that data is returned
:param idx: an integer index or a list of indices
:return:
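As the comment above notes, each dataset's cumulative length is recorded so that a global idx can be mapped to a sub-dataset. A standalone sketch of that lookup (the names here are illustrative, not the class's internals):

import numpy as np

# map a global idx to (local_idx, ds_index) via cumulative dataset lengths
lengths = [30, 30, 300]          # len() of each dataset, as in the tests below
bounds = np.cumsum(lengths)      # [30, 60, 330]

def locate(idx):
    ds_index = int(np.searchsorted(bounds, idx, side='right'))
    local_idx = idx - (bounds[ds_index - 1] if ds_index > 0 else 0)
    return local_idx, ds_index

assert locate(0) == (0, 0)       # first instance of the first dataset
assert locate(45) == (15, 1)     # 16th instance of the second dataset
assert locate(329) == (269, 2)   # last instance of the third dataset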
@ -69,8 +71,9 @@ class _MixCollateFn:
Resolves which auto_collate and collate_fn to apply to a batch when there are multiple auto_collates and collate_fns
"""
def __init__(self, collate_fns: Optional[Union[List[Callable], Callable]] = None,
auto_collators: Optional[List[Callable]] = None) -> None:
def __init__(self, collate_fns: Union[List[Callable], Callable]) -> None:
if isinstance(collate_fns, Sequence):
self.collate_fns = lambda idx, lst: collate_fns[idx](lst)
elif callable(collate_fns):
@ -78,96 +81,124 @@ class _MixCollateFn:
else:
self.collate_fns = lambda idx, lst: lst
self.collate_fns = collate_fns
self.auto_collators = auto_collators
def __call__(self, ins_list: List) -> Dict:
"""
Each call to this method treats ins_list as sampled from a single dataset, so ds_index can only take one value
:param ins_list:
:return:
"""
_ins_list, _ds_index = [], 0
for ins, _ds_index in ins_list:
_ins_list.append(ins)
# apply auto_collate first
if self.auto_collators is not None:
_ins_list = self.auto_collators[_ds_index](_ins_list)
_ins_list = self.collate_fns(_ds_index, _ins_list)
return _ins_list
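For clarity, here is a minimal sketch of the dispatch rule implemented above: a per-dataset sequence of collate functions is indexed by the batch's ds_index, while a single callable is shared across all datasets (the helper below is illustrative):

from collections.abc import Sequence

def make_dispatch(collate_fns):
    # Sequence -> pick the function by ds_index; single callable -> shared
    if isinstance(collate_fns, Sequence):
        return lambda idx, lst: collate_fns[idx](lst)
    elif callable(collate_fns):
        return lambda idx, lst: collate_fns(lst)
    return lambda idx, lst: lst

per_ds = make_dispatch([sum, max])  # one collate_fn per dataset
assert per_ds(0, [1, 2, 3]) == 6    # a batch from dataset 0 uses sum
assert per_ds(1, [1, 2, 3]) == 3    # a batch from dataset 1 uses max
shared = make_dispatch(len)         # one collate_fn shared by every dataset
assert shared(0, [1, 2, 3]) == 3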
class MixDataLoader(DataLoader):
"""
The MixDataLoader provided for the following three cases:
1. Given a collection or list of datasets, sample them sequentially: after the first dataset is fully sampled, take the second and repeat, until all datasets are consumed
2. Given a collection or list of datasets, sample from any of them at random and combine the results into mixed batches for the user, until every dataset has been fully sampled
3. Given a collection or list of datasets, sample them in turn: cycle through the datasets, drawing one batch from each before moving to the next one,
repeating until one dataset, or all of them, is exhausted
``MixDataLoader`` is provided for the following four cases and currently only supports the ``torch`` framework. The value of mode must be one of ``['sequential', 'mix', 'polling', "Sampler"]``:
* When mode is ``'sequential'``, ``MixDataLoader`` treats the sequence or dict of datasets as one large mixed dataset and samples all the data one dataset
after another, following the order of the sequence or dict
* When mode is ``'mix'``, ``MixDataLoader`` treats the sequence or dict of datasets as one large mixed dataset and then, following the provided idx sequence, randomly samples
data across the mixed datasets to assemble each returned batch
* When mode is ``'polling'``, ``MixDataLoader`` follows the order of the datasets: it first returns a batch sampled from the first dataset,
then a batch from the second, and so on; after the last dataset returns a batch, a second batch is drawn from the first dataset, until every dataset has been sampled through
* When mode is a ``"Sampler"``, the Sampler is an instantiated object implementing __iter__() that returns a batch of indices of type List[int] on each iteration;
the Sampler must treat the input datasets as one large mixed dataset whose index range is ``0 <= idx < len(datasets[0])+...+len(datasets[x])``, and the parameters
sampler, drop_last and ds_ratio then have no effect
"""
def __init__(self, datasets: Union[List, Dict] = None, mode: Union[str, "Sampler"] = 'sequential',
collate_fn: Union[List[Callable], Callable, Dict[str, Callable]] = None,
sampler: Union[List["Sampler"], Dict[str, "Sampler"]] = None,
def __init__(self, datasets: Dict = None, mode: Union[str, "Sampler"] = 'sequential',
collate_fn: Union[str, Callable, Dict[str, Callable]] = 'auto',
sampler: Union[Dict[str, "Sampler"], str, None] = None,
num_workers: int = 0, batch_size: int = 16, drop_last=False,
ds_ratio: Union[str, List[float], None, Dict[str, float]] = None,
pin_memory: bool = True) -> None:
ds_ratio: Union[None, str, Dict[str, float]] = None,
pin_memory: bool = False) -> None:
"""
:param datasets: a list of datasets
:param mode: mode takes four kinds of values; the first three, "sequential", "mix" and "polling", correspond to the three cases above.
When mode is a Sampler it is user-customized; sampler, ds_ratio, batch_size and drop_last are then ignored, and the Sampler should be an iterable
object that returns a List[int] on each iteration
:param collate_fn: the callable that packs the fetched data.
When it is a callable, the data sampled from every dataset passes through this function;
when it is a List[Callable], datasets should also be a List, and the Callable to apply is determined by the idx returned by each dataset's __getitem__,
matching the position in datasets;
when it is a Dict[str, Callable], datasets must also be a Dict whose keys correspond one to one
:param sampler: the instantiated sampler used for sampling inside each dataset of datasets.
When sampler is None, each dataset in datasets is given a SequentialSampler for sampling;
when sampler is a List[Sampler], datasets must also be a List with one-to-one correspondence;
when sampler is a Dict[str, Sampler], datasets must also be a Dict whose keys correspond one to one
:param num_workers: the number of worker processes; num_workers=0 disables multiprocessing
:param batch_size: the batch size; all the datasets in datasets share the same batch_size
:param drop_last: whether to drop the last batch that does not fill batch_size
:param ds_ratio: when ds_ratio is None, the original datasets are not expanded;
when ds_ratio is 'truncate_to_least', the shortest dataset in datasets is the baseline and the other datasets are truncated to the same length;
when ds_ratio is 'pad_to_most', the longest dataset is the baseline and the shorter datasets are resampled until they match its length;
when ds_ratio is a List[float], datasets must also be a List, and each entry of ds_ratio is the sampling multiplier for the corresponding dataset;
it must be greater than 0 and may exceed 1, in which case the dataset is simply resampled with repetition;
when ds_ratio is a Dict[str, float], datasets must also be a Dict with the parameters corresponding to each other
:param datasets: a sequence or dict of objects implementing __getitem__() and __len__()
:param mode: mode controls how ``MixDataLoader`` runs; its value must be one of ``['sequential', 'mix', 'polling', "Sampler"]``:
* When mode is ``'sequential'``, ``MixDataLoader`` treats the sequence or dict of datasets as one large mixed dataset and samples all the data one dataset
after another, following the order of the sequence or dict
* When mode is ``'mix'``, ``MixDataLoader`` treats the sequence or dict of datasets as one large mixed dataset and then, following the provided idx sequence, randomly samples
data across the mixed datasets to assemble each returned batch
* When mode is ``'polling'``, ``MixDataLoader`` follows the order of the datasets: it first returns a batch sampled from the first dataset,
then a batch from the second, and so on; after the last dataset returns a batch, a second batch is drawn from the first dataset, until every dataset has been sampled through
* When mode is a ``"Sampler"``, the Sampler is an instantiated object implementing __iter__() that returns a batch of indices of type List[int] on each iteration;
the Sampler must treat the input datasets as one large mixed dataset whose index range is ``0 <= idx < len(datasets[0])+...+len(datasets[x])``, and the parameters
sampler, drop_last and ds_ratio then have no effect
:param collate_fn: the Callable used to pack a batch of data fetched from a dataset; its value must be one of ``['auto', Callable, Dict[str, Callable]]``:
* When collate_fn is ``'auto'``, ``MixDataLoader`` initializes one :class:`~fastNLP.core.collators.Collator` as the default for the sequence or dict of datasets.
Note that this is only usable when the data of every dataset in datasets is of type ``List`` or ``Dict``; otherwise the user must define collate_fn themselves.
* When collate_fn is a ``Callable``, that collate_fn is shared by all the datasets in the sequence or dict. The Callable should accept a single batch argument,
where batch is a List object in which each element is one instance of a dataset, and the Callable should return an object
* When collate_fn is a ``Dict[str, Callable]``, the keys of datasets must match the keys of collate_fn; ``MixDataLoader`` applies ``collate_fn[key]``
to the data of ``datasets[key]``, where ``collate_fn[key]`` is a Callable object
:param sampler: an instantiated object implementing __len__() and __iter__(), whose __iter__() returns one dataset index at a time. Its value must be one of
``[None, str, Dict[str, "Sampler"]]``:
* When sampler is ``None``, ``MixDataLoader`` initializes ``torch``'s ``SequentialSampler`` by default, which returns dataset indices in order
* When sampler is a ``str``, its value must be one of ``['rand', 'seq']``. With ``'rand'``, ``MixDataLoader`` initializes ``torch``'s ``RandomSampler``
by default, which samples dataset indices at random; with ``'seq'``, ``MixDataLoader`` initializes ``torch``'s ``SequentialSampler`` by default, which returns dataset indices in order
* When sampler is a ``Dict[str, "Sampler"]``, each ``Sampler`` is a user-defined instantiated object implementing __len__() and __iter__() that must return one int index per iteration.
The Dict keys must match the keys of datasets; that is, ``Dict[str, Sampler]`` initializes one Sampler for each dataset of the datasets dict
:param num_workers: when ``num_workers > 0``, ``MixDataLoader`` spawns num_workers subprocesses to process data, which speeds up data processing but
also consumes a lot of memory; ``num_workers=0`` disables subprocesses. Defaults to ``0``
:param batch_size: the batch size, ``16`` by default, effective only when batch_sampler is None; all the datasets in datasets share the same batch_size
:param drop_last: when ``drop_last=True``, ``MixDataLoader`` drops, for each dataset in datasets, the last batch whose length is smaller than ``batch_size``;
when ``drop_last=False``, that batch is returned. Defaults to ``False``
:param ds_ratio: ``ds_ratio`` is the key parameter controlling how datasets are combined into one large mixed dataset; its value must be one of ``[None, 'truncate_to_least', 'pad_to_most', Dict[str, float]]``:
* When ds_ratio is ``None``, the sequence or dict of datasets is not expanded
* When ds_ratio is ``'truncate_to_least'``, the length of the shortest dataset in datasets, ``min_len``, is computed and the other datasets are truncated
to that shortest length ``min_len``. The truncation is not physical: ``MixDataLoader`` samples each dataset down to the specified shortest length ``min_len`` according to the sampler
* When ds_ratio is ``'pad_to_most'``, the length of the longest dataset in datasets, ``max_len``, is computed and the other datasets are expanded
to that maximum length ``max_len``. The expansion is not physical: ``MixDataLoader`` resamples each dataset up to the specified maximum length ``max_len`` according to the sampler
* When ds_ratio is a ``Dict[str, float]``, datasets must also be of type ``Dict[str, DataSet]`` with matching keys; each ds_ratio value is any float greater than 0
and is the factor by which the corresponding datasets value is expanded or shrunk
"""
# If datasets is a Dict, other parameters such as collate_fn must be a Dict or a Callable
if not isinstance(datasets, Dict) and (isinstance(collate_fn, Callable) or isinstance(collate_fn, Dict)) and \
isinstance(sampler, Dict):
raise ValueError("when datasets is not a Dict, sampler must not be a Dict")
# If sampler is a dict, check that its keys match those of datasets
if isinstance(sampler, Dict):
for key in datasets.keys():
if key not in sampler:
raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!")
# If collate_fn is a dict, check that its keys match those of datasets
if isinstance(collate_fn, Dict):
if mode == 'mix':
raise ValueError(f"mode: {mode} does not support a Dict collate_fn, please use collate_fn=Callable or 'auto'")
for key in datasets.keys():
if key not in collate_fn:
raise ValueError(f"the key:{key} of datasets is not in collate_fn, where collate_fn is a dict!")
if isinstance(collate_fn, list):
if len(collate_fn) != len(datasets):
raise ValueError("the length of collate_fn != datasets!!")
if isinstance(collate_fn, str) and collate_fn == 'auto':
data_type = None
for idx, ds in enumerate(datasets.values()):
if idx == 0:
data_type = type(ds[0])
if type(ds[0]) != data_type or not (isinstance(ds[0], List) or isinstance(ds[0], Mapping)):
raise ValueError(f"when you use collate_fn={collate_fn}, every dataset must be a list or a dict. "
f"But dataset {idx - 1} data type is {data_type}, dataset {idx} data type is {type(ds[0])}")
if isinstance(sampler, list):
if len(sampler) != len(datasets):
raise ValueError("the length of sampler != datasets!!")
collate_fn = Collator(backend='torch')
# Convert a Dict into a List for _MixCollateFn to handle
# Convert a Dict-typed collate_fn into a List so that _MixCollateFn can locate the dataset by idx
if isinstance(collate_fn, Dict):
collate_fn = [fn for _, fn in collate_fn.items()]
# datasets may be fastNLP-style datasets or a mixture, so a check is needed
if isinstance(datasets, Dict):
dataset = [ds for _, ds in datasets.items()]
else:
dataset = datasets
auto_collators = []
for per_ds in dataset:
if isinstance(per_ds, DataSet):
auto_collators.append(per_ds.get_collator())
else:
# if there is no matching collator, set one that performs no operation
auto_collators.append(lambda x: x)
dataset = [ds for _, ds in datasets.items()]
# Wrap collate_fn to uniformly handle how it is used in the different cases
collate_fn = _MixCollateFn(collate_fn)
# A List-typed collate_fn has only two cases and needs to be wrapped
collate_fn = _MixCollateFn(collate_fn, auto_collators)
if mode == 'sequential':
batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler,
drop_last=drop_last, ds_ratio=ds_ratio)

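A usage sketch consistent with the test cases added in this commit (the datasets here are small placeholders):

from fastNLP import DataSet
from fastNLP.core.dataloaders import MixDataLoader

d1 = DataSet({'x': [[1, 2], [2, 3, 4]] * 8, 'y': [1, 0] * 8})
d2 = DataSet({'x': [[101, 201], [100]] * 8, 'y': [20, 10] * 8})
datasets = {'d1': d1, 'd2': d2}

# 'sequential': exhaust d1 first, then d2; collate_fn='auto' pads each batch
dl = MixDataLoader(datasets=datasets, mode='sequential', collate_fn='auto',
                   batch_size=4, drop_last=True)
for batch in dl:
    print(batch['x'].shape)

# 'polling': one batch from d1, then one from d2, alternating until done
dl_poll = MixDataLoader(datasets=datasets, mode='polling', batch_size=4)

# 'mix': a batch may interleave d1 and d2 instances; 'pad_to_most' resamples
# the shorter dataset up to the length of the longest one
dl_mix = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='pad_to_most')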

@ -21,9 +21,9 @@ class MixSampler:
Base class for mix samplers
"""
def __init__(self, dataset: Union[List, Dict], batch_size: int = None,
sampler: Union[List["Sampler"], Dict[str, "Sampler"], None, str] = None,
ds_ratio: Union[str, List[float], Dict[str, float]] = None,
def __init__(self, dataset: Dict, batch_size: int = None,
sampler: Union[Dict[str, "Sampler"], None, str] = None,
ds_ratio: Union[str, Dict[str, float]] = None,
drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None:
"""
@ -32,9 +32,12 @@ class MixSampler:
:param sampler: instantiated sampler(s); each dataset corresponds to one sampler object
:param drop_last: whether to drop the last batch, whose length is smaller than batch_size
"""
# If dataset is a Dict, other parameters such as collate_fn must be a Dict or a Callable
if isinstance(dataset, Dict) and isinstance(sampler, List):
raise ValueError(f"{sampler} must be dict")
# If sampler is a dict, check that its keys match those of datasets
if isinstance(sampler, Dict):
for key in dataset.keys():
if key not in sampler:
raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!")
if batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size))
@ -46,15 +49,7 @@ class MixSampler:
raise ValueError("if rank>=0 and word_size>=0, sampler must be str")
if sampler is None and (word_size < 0 or rank < 0):
if isinstance(dataset, List):
self.sampler = [SequentialSampler(ds) for ds in dataset]
elif isinstance(dataset, Dict):
self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
elif isinstance(sampler, List):
if len(sampler) != len(dataset):
raise ValueError("the length of sampler != the length of sampler")
self.sampler = sampler
self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
elif isinstance(sampler, Dict):
self.sampler = sampler
@ -68,26 +63,7 @@ class MixSampler:
# Compute the expanded total length total_len of the big dataset and the expanded per-dataset length sampler_len
sampler_lens, total_lens, sampler_index = [], 0, []
if isinstance(self.sampler, List):
if ds_ratio is None:
sampler_lens = [len(spl) for spl in self.sampler]
elif ds_ratio == 'pad_to_most':
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler)
elif ds_ratio == 'truncate_to_least':
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler)
elif isinstance(ds_ratio, List):
if not all(item >= 0 for item in ds_ratio):
raise ValueError("ds_ratio entries should be non-negative values, "
"but got ds_ratio={}".format(ds_ratio))
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, ds_ratio)]
else:
raise ValueError(f"{ds_ratio} must be 'pad_to_most' or 'truncate_to_least' or None or List")
total_lens = sum(sampler_lens)
elif isinstance(self.sampler, Dict):
if isinstance(self.sampler, Dict):
if ds_ratio is None:
sampler_lens = [len(spl) for _, spl in self.sampler.items()]
@ -100,7 +76,7 @@ class MixSampler:
sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len
elif isinstance(ds_ratio, Dict):
if not all(item >= 0 for item in ds_ratio):
if not all([item >= 0 for item in ds_ratio.values()]):
raise ValueError("ds_ratio entries should be non-negative values, "
"but got ds_ratio={}".format(ds_ratio))
sampler_lens = [int(len(spl) * ds_ratio[name]) for name, spl in self.sampler.items()]
@ -108,7 +84,7 @@ class MixSampler:
raise ValueError(f"{ds_ratio} must be 'pad_to_most' or 'truncate_to_least' or None or Dict")
total_lens = sum(sampler_lens)
# When sampler is a str, its initialization is deferred to the __iter__ method
if len(sampler_lens) > 0:
sampler_index = [sampler_lens[0]]
for idx in sampler_lens[1:]:
@ -160,75 +136,37 @@ class DopedSampler(MixSampler):
"""
A BatchSampler customized for MixDataLoader; it mixes samples drawn from the given datasets into batches and returns them one by one
"""
def __init__(self, dataset: Union[List, Dict], batch_size: int = None,
sampler: Union[List["Sampler"], Dict[str, "Sampler"], str] = None,
ds_ratio: Union[str, None, List[float], Dict[str, float]] = None,
def __init__(self, dataset: Dict, batch_size: int = None,
sampler: Union[Dict[str, "Sampler"], str] = None,
ds_ratio: Union[str, None, Dict[str, float]] = None,
drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None:
super(DopedSampler, self).__init__(dataset=dataset, batch_size=batch_size,
sampler=sampler, ds_ratio=ds_ratio,
drop_last=drop_last, rank=rank, word_size=word_size)
def __iter__(self) -> List[int]:
# When sampler is a str: single machine with one or multiple GPUs; 'rand' randomization can be implemented
if isinstance(self.sampler, str):
if self.sampler == 'seq':
if isinstance(self.datasets, List):
self.sampler = []
for per_ds in self.datasets:
if self.word_size >= 0 and self.rank >= 0:
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]))
else:
self.sampler.append(InnerSampler(list(range(len(per_ds)))))
elif isinstance(self.datasets, Dict):
self.sampler = {}
for name, per_ds in self.datasets.items():
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
self.sampler = {}
for name, per_ds in self.datasets.items():
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
elif self.sampler == 'rand':
if isinstance(self.datasets, List):
self.sampler = []
for per_ds in self.datasets:
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(per_ds), generator=g).tolist()
if self.word_size >= 0 and self.rank >= 0:
self.sampler.append(InnerSampler(indices[self.rank::self.word_size]))
else:
self.sampler.append(InnerSampler(indices))
elif isinstance(self.datasets, Dict):
self.sampler = {}
for name, per_ds in self.datasets.items():
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(per_ds), generator=g).tolist()
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(indices)
self.sampler = {}
for name, per_ds in self.datasets.items():
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(per_ds), generator=g).tolist()
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(indices)
# Compute the datasets that actually need processing according to the given ds_ratio
if isinstance(self.sampler, List):
if self.ds_ratio is None:
sampler_lens = [len(spl) for spl in self.sampler]
elif self.ds_ratio == 'pad_to_most':
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler)
elif self.ds_ratio == 'truncate_to_least':
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler)
elif isinstance(self.ds_ratio, List):
if not all(item >= 0 for item in self.ds_ratio):
raise ValueError("ds_ratio entries should be non-negative values, "
"but got ds_ratio={}".format(self.ds_ratio))
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)]
else:
raise ValueError(f"{self.ds_ratio} must be 'pad_to_most' or 'truncate_to_least' or None or List")
total_lens = sum(sampler_lens)
elif isinstance(self.sampler, Dict):
if isinstance(self.sampler, Dict):
if self.ds_ratio is None:
sampler_lens = [len(spl) for _, spl in self.sampler.items()]
@ -257,11 +195,11 @@ class DopedSampler(MixSampler):
sampler_index.append(temp + idx)
self.num_samplers = sampler_index
self.len_samplers = total_lens
# data of each batch; the total data count total_index; the samplers of each dataset
batch_idx, samplers = [], []
# use all the data on a single machine, otherwise use multiple GPUs
if self.rank < 0 or self.word_size < 0:
# choose unsigned int or unsigned long depending on the sampler length
if self.len_samplers > 42e8:
total_index = array.array('L', list(range(self.len_samplers)))
else:
@ -274,15 +212,17 @@ class DopedSampler(MixSampler):
else:
total_index = array.array('I', list(range(self.len_samplers))[self.rank::self.word_size])
start_idx = 0
# per dataset: the required length, the dataset's sampler, the dataset's base offset, and the sampler's index
for idx, (name, spl) in enumerate(self.sampler.items()):
end_idx = len(spl)
samplers.append((iter(spl), name, start_idx))
start_idx += end_idx
# fetch each dataset's sampler according to the sampler's type
if isinstance(self.sampler, List):
sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1]
samplers = [(iter(spl), idx, base_index)
for idx, (spl, base_index) in enumerate(zip(self.sampler, sampler_base_index))]
else:
sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1]
samplers = [(iter(spl), name, sampler_base_index[idx])
for idx, (name, spl) in enumerate(self.sampler.items())]
# sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1]
# samplers = [(iter(spl), name, sampler_base_index[idx])
# for idx, (name, spl) in enumerate(self.sampler.items())]
# generate random numbers
np.random.seed(self.epoch)
np.random.shuffle(total_index)
@ -295,7 +235,7 @@ class DopedSampler(MixSampler):
# re-initialize a new sampler; since it cannot be empty, StopIteration will never occur here
spl = iter(self.sampler[name])
batch_idx.append(next(spl) + base_index)
samplers[name] = (spl, name, base_index)
samplers[ds_index] = (spl, name, base_index)
if len(batch_idx) == self.batch_size:
yield batch_idx
batch_idx = []
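The re-initialization above is what lets 'pad_to_most' and ratios greater than 1.0 oversample a short dataset: once its sampler is exhausted, a fresh iterator over the same sampler restarts it, so its indices repeat. A standalone sketch of that cycling pattern (the helper below is illustrative):

def draw(sampler, n, base_index=0):
    # draw n indices from a sampler, restarting it whenever it is exhausted
    it = iter(sampler)
    out = []
    while len(out) < n:
        try:
            out.append(next(it) + base_index)
        except StopIteration:
            it = iter(sampler)  # the sampler is never empty, so this is safe
    return out

assert draw(range(3), 7) == [0, 1, 2, 0, 1, 2, 0]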
@ -343,63 +283,26 @@ class MixSequentialSampler(MixSampler):
# When sampler is a str: single machine with one or multiple GPUs; 'rand' randomization can be implemented
if isinstance(self.sampler, str):
if self.sampler == 'seq':
if isinstance(self.datasets, List):
self.sampler = []
for per_ds in self.datasets:
if self.word_size >= 0 and self.rank >= 0:
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]))
else:
self.sampler.append(InnerSampler(list(range(len(per_ds)))))
elif isinstance(self.datasets, Dict):
self.sampler = {}
for name, per_ds in self.datasets.items():
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
self.sampler = {}
for name, per_ds in self.datasets.items():
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
elif self.sampler == 'rand':
if isinstance(self.datasets, List):
self.sampler = []
for per_ds in self.datasets:
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(per_ds), generator=g).tolist()
if self.word_size >= 0 and self.rank >= 0:
self.sampler.append(InnerSampler(indices[self.rank::self.word_size]))
else:
self.sampler.append(InnerSampler(indices))
elif isinstance(self.datasets, Dict):
self.sampler = {}
for name, per_ds in self.datasets.items():
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(per_ds), generator=g).tolist()
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(indices)
# Compute the datasets that actually need processing according to the given ds_ratio
if isinstance(self.sampler, List):
if self.ds_ratio is None:
sampler_lens = [len(spl) for spl in self.sampler]
self.sampler = {}
for name, per_ds in self.datasets.items():
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(per_ds), generator=g).tolist()
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(indices)
elif self.ds_ratio == 'pad_to_most':
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler)
elif self.ds_ratio == 'truncate_to_least':
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler)
elif isinstance(self.ds_ratio, List):
if not all(item >= 0 for item in self.ds_ratio):
raise ValueError("ds_ratio entries should be non-negative values, "
"but got ds_ratio={}".format(self.ds_ratio))
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)]
else:
raise ValueError(f"{self.ds_ratio} must be 'pad_to_most' or 'truncate_to_least' or None or List")
total_lens = sum(sampler_lens)
elif isinstance(self.sampler, Dict):
# Compute the datasets that actually need processing according to the given ds_ratio
if isinstance(self.sampler, Dict):
if self.ds_ratio is None:
sampler_lens = [len(spl) for _, spl in self.sampler.items()]
@ -430,21 +333,20 @@ class MixSequentialSampler(MixSampler):
self.len_samplers = total_lens
batch_idx, total_index, samplers = [], list(range(self.len_samplers)), []
if isinstance(self.sampler, List):
if self.word_size > 0 and self.rank >= 0:
sampler_base_index = [0] + [len(spl) * self.word_size for spl in self.sampler][:-1]
else:
sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1]
samplers = [(iter(spl), idx, base_index) for idx, (spl, base_index) in
enumerate(zip(self.sampler, sampler_base_index))]
else:
if self.word_size > 0 and self.rank >= 0:
sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1]
else:
sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1]
start_idx = 0
samplers = [(iter(spl), name, sampler_base_index[idx])
for idx, (name, spl) in enumerate(self.sampler.items())]
# per dataset: the required length, the dataset's sampler, the dataset's base offset, and the sampler's index
for idx, (name, spl) in enumerate(self.sampler.items()):
end_idx = len(spl)
samplers.append((iter(spl), name, start_idx))
start_idx += end_idx
# if self.word_size > 0 and self.rank >= 0:
# sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1]
# else:
# sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1]
#
# samplers = [(iter(spl), name, sampler_base_index[idx])
# for idx, (name, spl) in enumerate(self.sampler.items())]
for idx in total_index:
ds_index = np.searchsorted(self.num_samplers, idx, side='right')
@ -455,7 +357,7 @@ class MixSequentialSampler(MixSampler):
# re-initialize a new sampler; since it cannot be empty, StopIteration will never occur here
spl = iter(self.sampler[name])
batch_idx.append(next(spl) + base_index)
samplers[name] = (spl, name, base_index)
samplers[ds_index] = (spl, name, base_index)
if len(batch_idx) == self.batch_size:
yield batch_idx
batch_idx = []
@ -506,63 +408,26 @@ class PollingSampler(MixSampler):
# When sampler is a str: single machine with one or multiple GPUs; 'rand' randomization can be implemented
if isinstance(self.sampler, str):
if self.sampler == 'seq':
if isinstance(self.datasets, List):
self.sampler = []
for per_ds in self.datasets:
if self.word_size >= 0 and self.rank >= 0:
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]))
else:
self.sampler.append(InnerSampler(list(range(len(per_ds)))))
elif isinstance(self.datasets, Dict):
self.sampler = {}
for name, per_ds in self.datasets.items():
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
self.sampler = {}
for name, per_ds in self.datasets.items():
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
elif self.sampler == 'rand':
if isinstance(self.datasets, List):
self.sampler = []
for per_ds in self.datasets:
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(per_ds), generator=g).tolist()
if self.word_size >= 0 and self.rank >= 0:
self.sampler.append(InnerSampler(indices[self.rank::self.word_size]))
else:
self.sampler.append(InnerSampler(indices))
elif isinstance(self.datasets, Dict):
self.sampler = {}
for name, per_ds in self.datasets.items():
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(per_ds), generator=g).tolist()
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(indices)
self.sampler = {}
for name, per_ds in self.datasets.items():
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(per_ds), generator=g).tolist()
if self.word_size >= 0 and self.rank >= 0:
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
else:
self.sampler[name] = InnerSampler(indices)
# Compute the datasets that actually need processing according to the given ds_ratio
if isinstance(self.sampler, List):
if self.ds_ratio is None:
sampler_lens = [len(spl) for spl in self.sampler]
elif self.ds_ratio == 'pad_to_most':
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler)
elif self.ds_ratio == 'truncate_to_least':
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler)
elif isinstance(self.ds_ratio, List):
if not all(item >= 0 for item in self.ds_ratio):
raise ValueError("ds_ratio entries should be non-negative values, "
"but got ds_ratio={}".format(self.ds_ratio))
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)]
else:
raise ValueError(f"{self.ds_ratio} must be 'pad_to_most' or 'truncate_to_least' or None or List")
total_lens = sum(sampler_lens)
elif isinstance(self.sampler, Dict):
if isinstance(self.sampler, Dict):
if self.ds_ratio is None:
sampler_lens = [len(spl) for _, spl in self.sampler.items()]
@ -592,17 +457,15 @@ class PollingSampler(MixSampler):
self.num_samplers = sampler_index
self.len_samplers = total_lens
start_idx, samplers = 0, []
if isinstance(self.sampler, List):
# per dataset: the required length, the dataset's sampler, the dataset's base offset, and the sampler's index
for sampler_idx, (end_idx, spl) in enumerate(zip(self.num_samplers, self.sampler)):
samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, sampler_idx))
start_idx = end_idx
else:
for idx, (name, spl) in enumerate(self.sampler.items()):
end_idx = self.num_samplers[idx]
samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, name))
start_idx = end_idx
start_idx, samplers, true_start_idx, true_end_idx = 0, [], 0, 0
# per dataset: the required length, the dataset's sampler, the dataset's base offset, and the sampler's index
for idx, (name, spl) in enumerate(self.sampler.items()):
end_idx = len(spl)
true_end_idx = self.num_samplers[idx]
samplers.append((iter(range(true_start_idx, true_end_idx)), iter(spl), start_idx, name))
start_idx += end_idx
true_start_idx = true_end_idx
while True:
# exit the loop

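For intuition, the polling schedule assembled above can be sketched independently; this simplification ignores ds_ratio, base offsets, and distributed slicing:

def polling_batches(lengths, batch_size):
    # one batch from each dataset in turn; drop a dataset once it runs out
    iters = [iter(range(n)) for n in lengths]
    while iters:
        for it in list(iters):
            batch = [i for _, i in zip(range(batch_size), it)]
            if batch:
                yield batch
            if len(batch) < batch_size:
                iters.remove(it)

assert list(polling_batches([5, 3], 2)) == [[0, 1], [0, 1], [2, 3], [2], [4]]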

@ -0,0 +1,495 @@
import pytest
from typing import Mapping
from fastNLP.core.dataloaders import MixDataLoader
from fastNLP import DataSet
from fastNLP.core.collators import Collator
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch.utils.data import default_collate, SequentialSampler, RandomSampler
d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10] * 10})
d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100})
def test_pad_val(tensor, val=0):
# helper: each row's last element is either real data (> 0) or must equal the pad value val
if isinstance(tensor, torch.Tensor):
tensor = tensor.tolist()
for item in tensor:
if item[-1] > 0:
continue
elif item[-1] != val:
return False
return True
class TestMixDataLoader:
def test_sequential_init(self):
datasets = {'d1': d1, 'd2': d2, 'd3': d3}
# drop_last = True, collate_fn = 'auto'
dl = MixDataLoader(datasets=datasets, mode='sequential', collate_fn='auto', drop_last=True)
for idx, batch in enumerate(dl):
if idx == 0:
# d1
assert batch['x'].shape == torch.Size([16, 4])
if idx == 1:
# d2
assert batch['x'].shape == torch.Size([16, 3])
if idx > 1:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
# collate_fn = Callable
def collate_batch(batch):
new_batch = {'x': [], 'y': []}
for ins in batch:
new_batch['x'].append(ins['x'])
new_batch['y'].append(ins['y'])
return new_batch
dl1 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_batch, drop_last=True)
for idx, batch in enumerate(dl1):
if idx == 0:
# d1
assert [1, 2] in batch['x']
if idx == 1:
# d2
assert [101, 201] in batch['x']
if idx > 1:
# d3
assert [1000, 2000] in batch['x']
assert 'x' in batch and 'y' in batch
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1),
'd2': Collator(backend='auto').set_pad("x", -2),
'd3': Collator(backend='auto').set_pad("x", -3)}
dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True)
for idx, batch in enumerate(dl2):
if idx == 0:
assert test_pad_val(batch['x'], val=-1)
assert batch['x'].shape == torch.Size([16, 4])
if idx == 1:
assert test_pad_val(batch['x'], val=-2)
assert batch['x'].shape == torch.Size([16, 3])
if idx > 1:
assert test_pad_val(batch['x'], val=-3)
assert batch['x'].shape == torch.Size([16, 4])
# sampler as a str
dl3 = MixDataLoader(datasets=datasets, mode='sequential', sampler='seq', drop_last=True)
dl4 = MixDataLoader(datasets=datasets, mode='sequential', sampler='rand', drop_last=True)
for idx, batch in enumerate(dl3):
if idx == 0:
# d1
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
assert batch['x'].shape == torch.Size([16, 4])
if idx == 1:
# d2
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
assert batch['x'].shape == torch.Size([16, 3])
if idx == 2:
# d3
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
if idx > 1:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
for idx, batch in enumerate(dl4):
if idx == 0:
# d1
assert batch['x'][:3].tolist() != [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
assert batch['x'].shape == torch.Size([16, 4])
if idx == 1:
# d2
assert batch['x'][:3].tolist() != [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
assert batch['x'].shape == torch.Size([16, 3])
if idx == 2:
# d3
assert batch['x'][:3].tolist() != [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
if idx > 1:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
# sampler as a Dict
samplers = {'d1': SequentialSampler(d1),
'd2': SequentialSampler(d2),
'd3': RandomSampler(d3)}
dl5 = MixDataLoader(datasets=datasets, mode='sequential', sampler=samplers, drop_last=True)
for idx, batch in enumerate(dl5):
if idx == 0:
# d1
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
assert batch['x'].shape == torch.Size([16, 4])
if idx == 1:
# d2
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
assert batch['x'].shape == torch.Size([16, 3])
if idx > 1:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
# ds_ratio = 'truncate_to_least'
dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True)
for idx, batch in enumerate(dl6):
if idx == 0:
# d1
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
assert batch['x'].shape == torch.Size([16, 4])
if idx == 1:
# d2
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
assert batch['x'].shape == torch.Size([16, 3])
if idx == 2:
# d3
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
if idx > 2:
raise ValueError(f"ds_ratio: 'truncate_to_least' error")
# ds_ratio = 'pad_to_most'
dl7 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='pad_to_most', drop_last=True)
for idx, batch in enumerate(dl7):
if idx < 18:
# d1
assert batch['x'].shape == torch.Size([16, 4])
if 18 <= idx < 36:
# d2
assert batch['x'].shape == torch.Size([16, 3])
if 36 <= idx < 54:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
if idx >= 54:
raise ValueError(f"ds_ratio: 'pad_to_most' error")
# ds_ratio as a Dict[str, float]
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
dl8 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True)
for idx, batch in enumerate(dl8):
if idx < 1:
# d1
assert batch['x'].shape == torch.Size([16, 4])
if 1 <= idx < 4:
# d2
assert batch['x'].shape == torch.Size([16, 3])
if 4 <= idx < 41:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
if idx >= 41:
raise ValueError("ds_ratio: Dict[str, float] error")
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
dl9 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True)
for idx, batch in enumerate(dl9):
if idx < 1:
# d2
assert batch['x'].shape == torch.Size([16, 3])
if 1 <= idx < 19:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
if idx >= 19:
raise ValueError("ds_ratio: Dict[str, float] error")
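The batch-index boundaries asserted for dl8 and dl9 above follow from simple arithmetic, with batch_size=16, drop_last=True and the fixture lengths from the top of the file:

lens = {'d1': 30, 'd2': 30, 'd3': 300}
# dl8: ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
batches8 = {k: int(lens[k] * r) // 16 for k, r in {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}.items()}
assert batches8 == {'d1': 1, 'd2': 3, 'd3': 37}   # 41 batches: d1 at idx 0, d2 at 1-3, d3 at 4-40
# dl9: ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
batches9 = {k: int(lens[k] * r) // 16 for k, r in {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}.items()}
assert batches9 == {'d1': 0, 'd2': 1, 'd3': 18}   # 19 batches: d2 at idx 0, d3 at 1-18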
def test_mix(self):
datasets = {'d1': d1, 'd2': d2, 'd3': d3}
dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True)
for idx, batch in enumerate(dl):
assert test_pad_val(batch['x'], val=0)
if idx >= 22:
raise ValueError(f"out of range")
# collate_fn = Callable
def collate_batch(batch):
new_batch = {'x': [], 'y': []}
for ins in batch:
new_batch['x'].append(ins['x'])
new_batch['y'].append(ins['y'])
return new_batch
dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True)
for idx, batch in enumerate(dl1):
assert isinstance(batch['x'], list)
assert test_pad_val(batch['x'], val=0)
if idx >= 22:
raise ValueError(f"out of range")
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1),
'd2': Collator(backend='auto').set_pad("x", -2),
'd3': Collator(backend='auto').set_pad("x", -3)}
with pytest.raises(ValueError):
MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_fns)
# sampler as a str
dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True)
for idx, batch in enumerate(dl3):
assert test_pad_val(batch['x'], val=0)
if idx >= 22:
raise ValueError(f"out of range")
dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True)
for idx, batch in enumerate(dl4):
assert test_pad_val(batch['x'], val=0)
if idx >= 22:
raise ValueError(f"out of range")
# sampler as a Dict
samplers = {'d1': SequentialSampler(d1),
'd2': SequentialSampler(d2),
'd3': RandomSampler(d3)}
dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True)
for idx, batch in enumerate(dl5):
assert test_pad_val(batch['x'], val=0)
if idx >= 22:
raise ValueError(f"out of range")
# ds_ratio = 'truncate_to_least'
dl6 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='truncate_to_least')
d1_len, d2_len, d3_len = 0, 0, 0
for idx, batch in enumerate(dl6):
for item in batch['y'].tolist():
if item in [1, 0, 1]:
d1_len += 1
elif item in [20, 10, 10]:
d2_len += 1
elif item in [100, 100, 200]:
d3_len += 1
if idx >= 6:
raise ValueError("ds_ratio='truncate_to_least' error")
assert d1_len == d2_len == d3_len == 30
# ds_ratio = 'pad_to_most'
dl7 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='pad_to_most')
d1_len, d2_len, d3_len = 0, 0, 0
for idx, batch in enumerate(dl7):
for item in batch['y'].tolist():
if item in [1, 0, 1]:
d1_len += 1
elif item in [20, 10, 10]:
d2_len += 1
elif item in [100, 100, 200]:
d3_len += 1
if idx >= 57:
raise ValueError("ds_ratio='pad_to_most' error")
assert d1_len == d2_len == d3_len == 300
# ds_ratio as a Dict[str, float]
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
dl8 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio)
d1_len, d2_len, d3_len = 0, 0, 0
for idx, batch in enumerate(dl8):
for item in batch['y'].tolist():
if item in [1, 0, 1]:
d1_len += 1
elif item in [20, 10, 10]:
d2_len += 1
elif item in [100, 100, 200]:
d3_len += 1
if idx >= 44:
raise ValueError("ds_ratio as Dict error")
assert d1_len == 30
assert d2_len == 60
assert d3_len == 600
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
dl9 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio)
d1_len, d2_len, d3_len = 0, 0, 0
for idx, batch in enumerate(dl9):
for item in batch['y'].tolist():
if item in [1, 0, 1]:
d1_len += 1
elif item in [20, 10, 10]:
d2_len += 1
elif item in [100, 100, 200]:
d3_len += 1
if idx >= 21:
raise ValueError("ds_ratio as Dict error")
def test_polling(self):
datasets = {'d1': d1, 'd2': d2, 'd3': d3}
dl = MixDataLoader(datasets=datasets, mode='polling', collate_fn='auto', batch_size=18)
for idx, batch in enumerate(dl):
if idx == 0 or idx == 3:
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
assert batch['x'].shape[1] == 4
elif idx == 1 or idx == 4:
# d2
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
assert batch['x'].shape[1] == 3
elif idx == 2 or 4 < idx <= 20:
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
assert batch['x'].shape[1] == 4
if idx > 20:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
# collate_fn = Callable
def collate_batch(batch):
new_batch = {'x': [], 'y': []}
for ins in batch:
new_batch['x'].append(ins['x'])
new_batch['y'].append(ins['y'])
return new_batch
dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_batch, batch_size=18)
for idx, batch in enumerate(dl1):
if idx == 0 or idx == 3:
assert batch['x'][:3] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]]
elif idx == 1 or idx == 4:
# d2
assert batch['x'][:3] == [[101, 201], [201, 301, 401], [100]]
elif idx == 2 or 4 < idx <= 20:
assert batch['x'][:3] == [[1000, 2000], [0], [2000, 3000, 4000, 5000]]
if idx > 20:
raise ValueError(f"out of range")
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1),
'd2': Collator(backend='auto').set_pad("x", -2),
'd3': Collator(backend='auto').set_pad("x", -3)}
dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18)
for idx, batch in enumerate(dl1):
if idx == 0 or idx == 3:
assert test_pad_val(batch['x'], val=-1)
assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]]
assert batch['x'].shape[1] == 4
elif idx == 1 or idx == 4:
# d2
assert test_pad_val(batch['x'], val=-2)
assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]]
assert batch['x'].shape[1] == 3
elif idx == 2 or 4 < idx <= 20:
assert test_pad_val(batch['x'], val=-3)
assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]]
assert batch['x'].shape[1] == 4
if idx > 20:
raise ValueError(f"out of range")
# sampler as a str
dl2 = MixDataLoader(datasets=datasets, mode='polling', sampler='seq', batch_size=18)
dl3 = MixDataLoader(datasets=datasets, mode='polling', sampler='rand', batch_size=18)
for idx, batch in enumerate(dl2):
if idx == 0 or idx == 3:
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
assert batch['x'].shape[1] == 4
elif idx == 1 or idx == 4:
# d2
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
assert batch['x'].shape[1] == 3
elif idx == 2 or 4 < idx <= 20:
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
assert batch['x'].shape[1] == 4
if idx > 20:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
for idx, batch in enumerate(dl3):
if idx == 0 or idx == 3:
assert batch['x'].shape[1] == 4
elif idx == 1 or idx == 4:
# d2
assert batch['x'].shape[1] == 3
elif idx == 2 or 4 < idx <= 20:
assert batch['x'].shape[1] == 4
if idx > 20:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
# sampler as a Dict
samplers = {'d1': SequentialSampler(d1),
'd2': SequentialSampler(d2),
'd3': RandomSampler(d3)}
dl4 = MixDataLoader(datasets=datasets, mode='polling', sampler=samplers, batch_size=18)
for idx, batch in enumerate(dl4):
if idx == 0 or idx == 3:
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
assert batch['x'].shape[1] == 4
elif idx == 1 or idx == 4:
# d2
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
assert batch['x'].shape[1] == 3
elif idx == 2 or 4 < idx <= 20:
assert batch['x'].shape[1] == 4
if idx > 20:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
# ds_ratio = 'truncate_to_least'
dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18)
for idx, batch in enumerate(dl5):
if idx == 0 or idx == 3:
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
assert batch['x'].shape[1] == 4
elif idx == 1 or idx == 4:
# d2
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
assert batch['x'].shape[1] == 3
elif idx == 2 or idx == 5:
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
assert batch['x'].shape[1] == 4
if idx > 5:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
# ds_ratio = 'pad_to_most'
dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18)
for idx, batch in enumerate(dl6):
if idx % 3 == 0:
# d1
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
assert batch['x'].shape[1] == 4
if idx % 3 == 1:
# d2
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
assert batch['x'].shape[1] == 3
if idx % 3 == 2:
# d3
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
assert batch['x'].shape[1] == 4
if idx >= 51:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
# ds_ratio as a Dict[str, float]
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
dl7 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
for idx, batch in enumerate(dl7):
if idx == 0 or idx == 3:
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
assert batch['x'].shape[1] == 4
elif idx == 1 or idx == 4 or idx == 6 or idx == 8:
# d2
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
assert batch['x'].shape[1] == 3
elif idx == 2 or idx == 5 or idx == 7 or idx > 8:
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
assert batch['x'].shape[1] == 4
if idx > 39:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
for idx, batch in enumerate(dl8):
if idx == 0:
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
assert batch['x'].shape[1] == 4
elif idx == 1:
# d2
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
assert batch['x'].shape[1] == 3
elif idx > 1:
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
assert batch['x'].shape[1] == 4
if idx > 18:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)