mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-29 18:59:01 +08:00
增加 mixdataloader 文档, 修改mix_sampler, mixdataloader代码, 增加相应测试用例
This commit is contained in:
parent
916c113322
commit
9bb1ed4ccf
@ -24,7 +24,7 @@ from fastNLP.core.dataset import DataSet as FDataSet
|
||||
class _JittorDataset(Dataset):
|
||||
"""
|
||||
对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dataset) -> None:
|
||||
@ -37,7 +37,7 @@ class _JittorDataset(Dataset):
|
||||
item = item.tolist()
|
||||
return (item, self.dataset[item])
|
||||
|
||||
|
||||
|
||||
class JittorDataLoader:
|
||||
"""
|
||||
提供给 ``jittor`` 框架使用的 ``DataLoader`` 函数,``JittorDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad,
|
||||
|
@ -2,13 +2,14 @@ __all__ = [
|
||||
'MixDataLoader'
|
||||
]
|
||||
|
||||
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence
|
||||
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.dataset import DataSet, Instance
|
||||
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||
from fastNLP.core.collators import Collator
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
from torch.utils.data import DataLoader, Sampler
|
||||
@ -18,12 +19,13 @@ else:
|
||||
|
||||
class _MixDataset:
|
||||
"""
|
||||
将所有数据集当成一个混合大数据集来对待,实现的__getitem__能区别每个数据idx
|
||||
将所有数据集当成一个混合大数据集来对待, 在 __getitem__() 能根据输入的 idx 来判断属于哪个小数据并返回其 ds_index
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, datasets: list = None) -> None:
|
||||
"""
|
||||
:param datasets: 数据集的列表
|
||||
:param datasets: 实现了 __getitem__() 和 __len__() 的对象的序列
|
||||
"""
|
||||
self.datasets = datasets
|
||||
# 记录每个数据集的长度索引, 以便根据idx定位数据集的位置
|
||||
@ -35,7 +37,7 @@ class _MixDataset:
|
||||
|
||||
def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]:
|
||||
"""
|
||||
根据index索引获取数据
|
||||
根据index索引获取数据, 能够跟 idx 的范围定位属于哪个小数据并返回
|
||||
|
||||
:param idx: 整数类型的index或者列表
|
||||
:return:
|
||||
@ -69,8 +71,9 @@ class _MixCollateFn:
|
||||
存在多个auto_collate和多个collate_fn时候,对一个批次数据集应用哪个auto_collate和collate_fn的问题
|
||||
|
||||
"""
|
||||
def __init__(self, collate_fns: Optional[Union[List[Callable], Callable]] = None,
|
||||
auto_collators: Optional[List[Callable]] = None) -> None:
|
||||
|
||||
def __init__(self, collate_fns: Union[List[Callable], Callable]) -> None:
|
||||
|
||||
if isinstance(collate_fns, Sequence):
|
||||
self.collate_fns = lambda idx, lst: collate_fns[idx](lst)
|
||||
elif callable(collate_fns):
|
||||
@ -78,96 +81,124 @@ class _MixCollateFn:
|
||||
else:
|
||||
self.collate_fns = lambda idx, lst: lst
|
||||
|
||||
self.collate_fns = collate_fns
|
||||
self.auto_collators = auto_collators
|
||||
|
||||
def __call__(self, ins_list: List) -> Dict:
|
||||
"""
|
||||
调用一次该方法,我们将ins_list视为同一个数据集采样出来的,故ds_index只能为一种
|
||||
|
||||
:param ins_list:
|
||||
:return:
|
||||
"""
|
||||
_ins_list, _ds_index = [], 0
|
||||
for ins, _ds_index in ins_list:
|
||||
_ins_list.append(ins)
|
||||
# auto_collate先处理
|
||||
if self.auto_collators is not None:
|
||||
_ins_list = self.auto_collators[_ds_index](_ins_list)
|
||||
_ins_list = self.collate_fns(_ds_index, _ins_list)
|
||||
return _ins_list
|
||||
|
||||
|
||||
class MixDataLoader(DataLoader):
|
||||
"""
|
||||
针对一下三种情况提供的MixDataLoader:
|
||||
1. 给定datasets集合或者列表,顺序采样datasets,处理采样完首个dataset后取出第二个dataset,重复上面过程直至datasets取完。
|
||||
2. 给定datasets集合或者列表,随机采样这个datasets的任意一个数据集组合成一个混合的batch返回给用户,直至datasets所有数据集采样完。
|
||||
3. 给定datasets集合或者列表,轮流采样datasets:即是循环遍历datasets,每取出一个dataset采样一个batch的数据,然后取出下一个dataset
|
||||
采样一个batch数据,重复上述过程直至某个dataset采样结束或者所有dataset采样结束。
|
||||
针对一下四种情况提供的 ``MixDataLoader``, 目前只支持 ``torch`` 框架的版本, 其中 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``:
|
||||
|
||||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个
|
||||
接一个的 sample 完所有数据。
|
||||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample
|
||||
混合数据集 datasets 的数据组成一个 batch 序列返回。
|
||||
* 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回,
|
||||
再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。
|
||||
* 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int];
|
||||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数
|
||||
sampler, drop_last, ds_ratio 均无效。
|
||||
|
||||
"""
|
||||
def __init__(self, datasets: Union[List, Dict] = None, mode: Union[str, "Sampler"] = 'sequential',
|
||||
collate_fn: Union[List[Callable], Callable, Dict[str, Callable]] = None,
|
||||
sampler: Union[List["Sampler"], Dict[str, "Sampler"]] = None,
|
||||
|
||||
def __init__(self, datasets: Dict = None, mode: Union[str, "Sampler"] = 'sequential',
|
||||
collate_fn: Union[str, Callable, Dict[str, Callable]] = 'auto',
|
||||
sampler: Union[Dict[str, "Sampler"], str, None] = None,
|
||||
num_workers: int = 0, batch_size: int = 16, drop_last=False,
|
||||
ds_ratio: Union[str, List[float], None, Dict[str, float]] = None,
|
||||
pin_memory: bool = True) -> None:
|
||||
ds_ratio: Union[None, str, Dict[str, float]] = None,
|
||||
pin_memory: bool = False) -> None:
|
||||
"""
|
||||
|
||||
:param datasets: dataset的列表
|
||||
:param mode: mode包括四种类型,前三种分别为"sequential", "mix", "polling"分别代表上述三种情况,
|
||||
当mode为Sampler时为用户定制,此时sampler,ds_ratio,batch_size,drop_last失效,此时Sampler应该是一个可迭代
|
||||
对象,每次迭代返回的是List[int]
|
||||
:param collate_fn: 对取得到的数据进行打包的callable函数,
|
||||
当其为callable类型时候,所有数据集采样的数据都会经过这个函数;
|
||||
当其为List[Callable]类型时,datasets也应该为List;会根据每个数据集__getitem__返回的idx判断当前数据对应的Callable函数,
|
||||
其对应关系与datasets位置匹配;
|
||||
当其为Dict[str, Callable]类型时, datasets也是Dict类型且一一对应。
|
||||
:param sampler: sampler是datasets每个数据集内部采样的实例化sampler对象
|
||||
sampler为None时候,datasets包含的每个dataset都会初始化一个sequentialSampler用于采样;
|
||||
sampler为List[Sampler],则datasets也为List,且一一对应
|
||||
sampler为Dict[str, Sampler], datasets也是Dict类型且一一对应。
|
||||
:param num_workers: 进程的数量,当num_workers=0时不开启多进程
|
||||
:param batch_size: 批次大小, datasets的所有数据集batch_size一致
|
||||
:param drop_last: 是否去掉最后一个不符合batch_size的数据
|
||||
:param ds_ratio: 当ds_ratio为None,原有数据集不进行扩充
|
||||
当ds_ratio为'truncate_to_least'时,以datasets的最短数据集为基准,将其他数据集截断到一样长度
|
||||
当ds_ratio为'pad_to_most'时,以datasets的最长数据集为基准,将最短数据集重采样到最长数据集长度一致为止
|
||||
当ds_ratio为List[float]时,datasets也为List,ds_ratio的每一个参数都是datasets每个数据集应该采样的倍数,
|
||||
其大于0,可以超过1,将数据集重采样翻倍即可
|
||||
当ds_ratio为Dict[str, float]时,datasets也为Dict,参数相互对应。
|
||||
:param datasets: 实现了 __getitem__() 和 __len__() 对象的序列或者字典。
|
||||
:param mode: mode 控制 ``MixDataLoader`` 运行模式。 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``:
|
||||
|
||||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个
|
||||
接一个的 sample 完所有数据。
|
||||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample
|
||||
混合数据集 datasets 的数据组成一个 batch 序列返回。
|
||||
* 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回,
|
||||
再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。
|
||||
* 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int];
|
||||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数
|
||||
sampler, drop_last, ds_ratio 均无效。
|
||||
|
||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数。 其取值可以为 ``['auto', Callable, List[Callable], Dict[str, Callable]]``:
|
||||
|
||||
* collate_fn 为 ``'auto'`` 时, ``MixDataLoader`` datasets 序列或者dict 初始化一个 :class: `~fastNLP.core.collators.Collator` 作为其默认值,
|
||||
需要注意的是只有当 datasets 包含的所以 dataset 的数据都为 ``List`` 或者 ``Dict`` 类型时才能使用。否则只能用户自己定义 collate_fn .
|
||||
* collate_fn 为 ``Callable`` 时, 该 collate_fn 会被 datasets 序列或者dict 的所有数据所共享。该 Callable 函数应当接受一个 batch 参数作为输入,
|
||||
batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。
|
||||
* collate_fn 为 ``Dict[str, Callable]`` 时, datasets 的 key 必须和 callable_fn 的 key 一致。 ``MixDataLoader`` 会将 ``collate_fn[key]``
|
||||
用到 ``datasets[key]`` 的数据集上。 ``collate_fn[key]`` 是一个 Callable 对象。
|
||||
|
||||
|
||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,其取值范围为
|
||||
``[None, str, Dict[str, "Sampler"]]``:
|
||||
|
||||
* sampler 为 ``None`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。
|
||||
* sampler 为 ``str`` 时, sampler 选择范围为 ``[rand, seq]``。当 sampler 为 ``rand`` 时,``MixDataLoader`` 默认初始化 ``torch`` 的 ``RandomSampler``
|
||||
作为默认值, 其功能时随机采样 dataset 的下标并返回。 当 sampler 为 ``seq`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。
|
||||
* sampler 为 ``Dict[str, "Sampler"]`` 时, ``Sampler`` 为用户定义的实现了 __len__() 和 __iter__() 的实例化对象。 其每次 iter 必须返回一个 int 下标。
|
||||
Dict 的 str 必须和 datasets 的 key 一致。 也即是 ``Dict[str, Sampler] `` 为 datasets 字典的每个 dataset 初始化勒一个 Sampler。
|
||||
|
||||
:param num_workers: 当 ``num_workers > 0`` 时, ``MixDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快数据处理速度,但同时
|
||||
也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
|
||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 且 datasets 上所有 dataset 的 batch_size 一致。
|
||||
:param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据;
|
||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
|
||||
:param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``:
|
||||
|
||||
* ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。
|
||||
* ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最断长度 ``mix_len``, 其他数据集会被切断
|
||||
到最短长度``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度``mix_len``。
|
||||
* ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他其他数据集会扩充
|
||||
到最大长度``mix_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度``max_len``。
|
||||
* ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数,
|
||||
代表着 datasets 的 value 数据进行扩充或者缩减的倍数。
|
||||
"""
|
||||
# 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable,
|
||||
if not isinstance(datasets, Dict) and (isinstance(collate_fn, Callable) or isinstance(collate_fn, Dict)) and \
|
||||
isinstance(sampler, Dict):
|
||||
raise ValueError(f"")
|
||||
# sampler 为 dict,则判断是否与 datasets 的 key 相同
|
||||
if isinstance(sampler, Dict):
|
||||
for key in datasets.keys():
|
||||
if not sampler[key]:
|
||||
raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!")
|
||||
# collate_fn 为 dict,则判断是否与 datasets 的 key 相同
|
||||
if isinstance(collate_fn, Dict):
|
||||
if mode == 'mix':
|
||||
raise ValueError(f"mode: {mode} do not support collate_fn is Dict, please use callate_fn=Callable or 'auto'")
|
||||
for key in datasets.keys():
|
||||
if not collate_fn[key]:
|
||||
raise ValueError(f"the key:{key} of datasets is not in collate_fn, where collate_fn is a dict!")
|
||||
|
||||
if isinstance(collate_fn, list):
|
||||
if len(collate_fn) != len(datasets):
|
||||
raise ValueError("the length of collate_fn != datasets!!")
|
||||
if isinstance(collate_fn, str) and collate_fn == 'auto':
|
||||
date_type = None
|
||||
for idx, ds in enumerate(datasets.values()):
|
||||
if idx == 0:
|
||||
date_type = type(ds[0])
|
||||
if type(ds[0]) != date_type or not (isinstance(ds[0], List) or isinstance(ds[0], Mapping)):
|
||||
raise ValueError(f"when you use callate_fn={collate_fn}, all dataset must be list or dict。"
|
||||
f"But dataset {idx - 1} data type is {date_type}, dataset {idx} data type is {type(ds[0])}")
|
||||
|
||||
if isinstance(sampler, list):
|
||||
if len(sampler) != len(datasets):
|
||||
raise ValueError("the length of sampler != datasets!!")
|
||||
collate_fn = Collator(backend='torch')
|
||||
|
||||
# Dict类型转化为List,以便于_MixCollateFn处理
|
||||
# Dict 类型的 collate_fn 转化为 List,以便于 _MixCollateFn 里面根据 idx 定位 dataset
|
||||
if isinstance(collate_fn, Dict):
|
||||
collate_fn = [fn for _, fn in collate_fn.items()]
|
||||
|
||||
# 由于datasets可能是FastNLP类型的dataset或者是交杂的, 故需要检测
|
||||
if isinstance(datasets, Dict):
|
||||
dataset = [ds for _, ds in datasets.items()]
|
||||
else:
|
||||
dataset = datasets
|
||||
auto_collators = []
|
||||
for per_ds in dataset:
|
||||
if isinstance(per_ds, DataSet):
|
||||
auto_collators.append(per_ds.get_collator())
|
||||
else:
|
||||
# 如果没有对应的collator就设置一个不做任何操作的collator
|
||||
auto_collators.append(lambda x: x)
|
||||
dataset = [ds for _, ds in datasets.items()]
|
||||
|
||||
# 对 collate_fn 进行包裹, 统一处理 collate_fn 不同情况下使用的问题
|
||||
collate_fn = _MixCollateFn(collate_fn)
|
||||
|
||||
# List类型的collate_fn只有两种情况,需要对其进行包裹
|
||||
collate_fn = _MixCollateFn(collate_fn, auto_collators)
|
||||
if mode == 'sequential':
|
||||
batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler,
|
||||
drop_last=drop_last, ds_ratio=ds_ratio)
|
||||
|
@ -21,9 +21,9 @@ class MixSampler:
|
||||
mix_sampler的基类
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Union[List, Dict], batch_size: int = None,
|
||||
sampler: Union[List["Sampler"], Dict[str, "Sampler"], None, str] = None,
|
||||
ds_ratio: Union[str, List[float], Dict[str, float]] = None,
|
||||
def __init__(self, dataset: Dict, batch_size: int = None,
|
||||
sampler: Union[Dict[str, "Sampler"], None, str] = None,
|
||||
ds_ratio: Union[str, Dict[str, float]] = None,
|
||||
drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None:
|
||||
"""
|
||||
|
||||
@ -32,9 +32,12 @@ class MixSampler:
|
||||
:param sampler: 实例化好的sampler,每个dataset对应一个sampler对象
|
||||
:param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size
|
||||
"""
|
||||
# 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable,
|
||||
if isinstance(dataset, Dict) and isinstance(sampler, List):
|
||||
raise ValueError(f"{sampler} must be dict")
|
||||
# sampler 为 dict,则判断是否与 datasets 的 key 相同
|
||||
if isinstance(sampler, Dict):
|
||||
for key in dataset.keys():
|
||||
if not sampler[key]:
|
||||
raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!")
|
||||
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size should be a positive integer value, "
|
||||
"but got batch_size={}".format(batch_size))
|
||||
@ -46,15 +49,7 @@ class MixSampler:
|
||||
raise ValueError("if rank>=0 and word_size>=0, sampler must be str")
|
||||
|
||||
if sampler is None and (word_size < 0 or rank < 0):
|
||||
if isinstance(dataset, List):
|
||||
self.sampler = [SequentialSampler(ds) for ds in dataset]
|
||||
elif isinstance(dataset, Dict):
|
||||
self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
|
||||
|
||||
elif isinstance(sampler, List):
|
||||
if len(sampler) != len(dataset):
|
||||
raise ValueError("the length of sampler != the length of sampler")
|
||||
self.sampler = sampler
|
||||
self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
|
||||
|
||||
elif isinstance(sampler, Dict):
|
||||
self.sampler = sampler
|
||||
@ -68,26 +63,7 @@ class MixSampler:
|
||||
|
||||
# 计算扩展后的大数据集长度total_len和扩展后的单个数据集长度sampler_len
|
||||
sampler_lens, total_lens, sampler_index = [], 0, []
|
||||
if isinstance(self.sampler, List):
|
||||
if ds_ratio is None:
|
||||
sampler_lens = [len(spl) for spl in self.sampler]
|
||||
|
||||
elif ds_ratio == 'pad_to_most':
|
||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler)
|
||||
|
||||
elif ds_ratio == 'truncate_to_least':
|
||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler)
|
||||
|
||||
elif isinstance(ds_ratio, List):
|
||||
if not all(item >= 0 for item in ds_ratio):
|
||||
raise ValueError("batch_size should be a positive integer value, "
|
||||
"but got ds_ratio={}".format(ds_ratio))
|
||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, ds_ratio)]
|
||||
else:
|
||||
raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List")
|
||||
total_lens = sum(sampler_lens)
|
||||
|
||||
elif isinstance(self.sampler, Dict):
|
||||
if isinstance(self.sampler, Dict):
|
||||
if ds_ratio is None:
|
||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()]
|
||||
|
||||
@ -100,7 +76,7 @@ class MixSampler:
|
||||
sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len
|
||||
|
||||
elif isinstance(ds_ratio, Dict):
|
||||
if not all(item >= 0 for item in ds_ratio):
|
||||
if not all([item >= 0 for item in ds_ratio.values()]):
|
||||
raise ValueError("batch_size should be a positive integer value, "
|
||||
"but got ds_ratio={}".format(ds_ratio))
|
||||
sampler_lens = [int(len(spl) * ds_ratio[name]) for name, spl in self.sampler.items()]
|
||||
@ -108,7 +84,7 @@ class MixSampler:
|
||||
raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List")
|
||||
total_lens = sum(sampler_lens)
|
||||
|
||||
# sampler为str时候,初始化下移到iter方法中
|
||||
# sampler 为 str 时候,初始化下移到 iter 方法中
|
||||
if len(sampler_lens) > 0:
|
||||
sampler_index = [sampler_lens[0]]
|
||||
for idx in sampler_lens[1:]:
|
||||
@ -160,75 +136,37 @@ class DopedSampler(MixSampler):
|
||||
"""
|
||||
定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表混合采样组成一个个batch返回。
|
||||
"""
|
||||
def __init__(self, dataset: Union[List, Dict], batch_size: int = None,
|
||||
sampler: Union[List["Sampler"], Dict[str, "Sampler"], str] = None,
|
||||
ds_ratio: Union[str, None, List[float], Dict[str, float]] = None,
|
||||
def __init__(self, dataset: Dict, batch_size: int = None,
|
||||
sampler: Union[Dict[str, "Sampler"], str] = None,
|
||||
ds_ratio: Union[str, None, Dict[str, float]] = None,
|
||||
drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None:
|
||||
super(DopedSampler, self).__init__(dataset=dataset, batch_size=batch_size,
|
||||
sampler=sampler, ds_ratio=ds_ratio,
|
||||
drop_last=drop_last, rank=rank, word_size=word_size)
|
||||
|
||||
def __iter__(self) -> List[int]:
|
||||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化
|
||||
# sampler 为 str, 此时为单机多卡或者单机,可以实现 rand 随机化
|
||||
if isinstance(self.sampler, str):
|
||||
if self.sampler == 'seq':
|
||||
if isinstance(self.datasets, List):
|
||||
self.sampler = []
|
||||
for per_ds in self.datasets:
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]))
|
||||
else:
|
||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))))
|
||||
elif isinstance(self.datasets, Dict):
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
|
||||
elif self.sampler == 'rand':
|
||||
if isinstance(self.datasets, List):
|
||||
self.sampler = []
|
||||
for per_ds in self.datasets:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
indices = torch.randperm(len(per_ds), generator=g).tolist()
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size]))
|
||||
else:
|
||||
self.sampler.append(InnerSampler(indices))
|
||||
elif isinstance(self.datasets, Dict):
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
indices = torch.randperm(len(per_ds), generator=g).tolist()
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(indices)
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
indices = torch.randperm(len(per_ds), generator=g).tolist()
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(indices)
|
||||
|
||||
# 根据给定的ds_ratio计算真正需要处理数据集
|
||||
if isinstance(self.sampler, List):
|
||||
if self.ds_ratio is None:
|
||||
sampler_lens = [len(spl) for spl in self.sampler]
|
||||
|
||||
elif self.ds_ratio == 'pad_to_most':
|
||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler)
|
||||
|
||||
elif self.ds_ratio == 'truncate_to_least':
|
||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler)
|
||||
|
||||
elif isinstance(self.ds_ratio, List):
|
||||
if not all(item >= 0 for item in self.ds_ratio):
|
||||
raise ValueError("batch_size should be a positive integer value, "
|
||||
"but got ds_ratio={}".format(self.ds_ratio))
|
||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)]
|
||||
else:
|
||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List")
|
||||
total_lens = sum(sampler_lens)
|
||||
|
||||
elif isinstance(self.sampler, Dict):
|
||||
if isinstance(self.sampler, Dict):
|
||||
if self.ds_ratio is None:
|
||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()]
|
||||
|
||||
@ -257,11 +195,11 @@ class DopedSampler(MixSampler):
|
||||
sampler_index.append(temp + idx)
|
||||
self.num_samplers = sampler_index
|
||||
self.len_samplers = total_lens
|
||||
# 每个batch的数据, 总的数据量total_index, 每个数据集的samplers
|
||||
# 每个 batch 的数据, 总的数据量 total_index , 每个数据集的 samplers
|
||||
batch_idx, samplers = [], []
|
||||
# 如果单机则用所有数据,否则采用多卡
|
||||
if self.rank < 0 or self.word_size < 0:
|
||||
# 根据sampler长度判断是否使用unsigned int 或者unsigned long
|
||||
# 根据 sampler 长度判断是否使用 unsigned int 或者 unsigned long
|
||||
if self.len_samplers > 42e8:
|
||||
total_index = array.array('L', list(range(self.len_samplers)))
|
||||
else:
|
||||
@ -274,15 +212,17 @@ class DopedSampler(MixSampler):
|
||||
else:
|
||||
total_index = array.array('I', list(range(self.len_samplers))[self.rank::self.word_size])
|
||||
|
||||
start_idx = 0
|
||||
|
||||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标)
|
||||
for idx, (name, spl) in enumerate(self.sampler.items()):
|
||||
end_idx = len(spl)
|
||||
samplers.append((iter(spl), name, start_idx))
|
||||
start_idx += end_idx
|
||||
# 根据sampler的类型取出每个数据集的sampler
|
||||
if isinstance(self.sampler, List):
|
||||
sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1]
|
||||
samplers = [(iter(spl), idx, base_index)
|
||||
for idx, (spl, base_index) in enumerate(zip(self.sampler, sampler_base_index))]
|
||||
else:
|
||||
sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1]
|
||||
samplers = [(iter(spl), name, sampler_base_index[idx])
|
||||
for idx, (name, spl) in enumerate(self.sampler.items())]
|
||||
# sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1]
|
||||
# samplers = [(iter(spl), name, sampler_base_index[idx])
|
||||
# for idx, (name, spl) in enumerate(self.sampler.items())]
|
||||
# 生成随机数
|
||||
np.random.seed(self.epoch)
|
||||
np.random.shuffle(total_index)
|
||||
@ -295,7 +235,7 @@ class DopedSampler(MixSampler):
|
||||
# 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration
|
||||
spl = iter(self.sampler[name])
|
||||
batch_idx.append(next(spl) + base_index)
|
||||
samplers[name] = (spl, name, base_index)
|
||||
samplers[ds_index] = (spl, name, base_index)
|
||||
if len(batch_idx) == self.batch_size:
|
||||
yield batch_idx
|
||||
batch_idx = []
|
||||
@ -343,63 +283,26 @@ class MixSequentialSampler(MixSampler):
|
||||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化
|
||||
if isinstance(self.sampler, str):
|
||||
if self.sampler == 'seq':
|
||||
if isinstance(self.datasets, List):
|
||||
self.sampler = []
|
||||
for per_ds in self.datasets:
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]))
|
||||
else:
|
||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))))
|
||||
elif isinstance(self.datasets, Dict):
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
|
||||
elif self.sampler == 'rand':
|
||||
if isinstance(self.datasets, List):
|
||||
self.sampler = []
|
||||
for per_ds in self.datasets:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
indices = torch.randperm(len(per_ds), generator=g).tolist()
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size]))
|
||||
else:
|
||||
self.sampler.append(InnerSampler(indices))
|
||||
elif isinstance(self.datasets, Dict):
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
indices = torch.randperm(len(per_ds), generator=g).tolist()
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(indices)
|
||||
|
||||
# 根据给定的ds_ratio计算真正需要处理数据集
|
||||
if isinstance(self.sampler, List):
|
||||
if self.ds_ratio is None:
|
||||
sampler_lens = [len(spl) for spl in self.sampler]
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
indices = torch.randperm(len(per_ds), generator=g).tolist()
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(indices)
|
||||
|
||||
elif self.ds_ratio == 'pad_to_most':
|
||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler)
|
||||
|
||||
elif self.ds_ratio == 'truncate_to_least':
|
||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler)
|
||||
|
||||
elif isinstance(self.ds_ratio, List):
|
||||
if not all(item >= 0 for item in self.ds_ratio):
|
||||
raise ValueError("batch_size should be a positive integer value, "
|
||||
"but got ds_ratio={}".format(self.ds_ratio))
|
||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)]
|
||||
else:
|
||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List")
|
||||
total_lens = sum(sampler_lens)
|
||||
|
||||
elif isinstance(self.sampler, Dict):
|
||||
# 根据给定的 ds_ratio 算真正需要处理数据集
|
||||
if isinstance(self.sampler, Dict):
|
||||
if self.ds_ratio is None:
|
||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()]
|
||||
|
||||
@ -430,21 +333,20 @@ class MixSequentialSampler(MixSampler):
|
||||
self.len_samplers = total_lens
|
||||
|
||||
batch_idx, total_index, samplers = [], list(range(self.len_samplers)), []
|
||||
if isinstance(self.sampler, List):
|
||||
if self.word_size > 0 and self.rank >= 0:
|
||||
sampler_base_index = [0] + [len(spl) * self.word_size for spl in self.sampler][:-1]
|
||||
else:
|
||||
sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1]
|
||||
samplers = [(iter(spl), idx, base_index) for idx, (spl, base_index) in
|
||||
enumerate(zip(self.sampler, sampler_base_index))]
|
||||
else:
|
||||
if self.word_size > 0 and self.rank >= 0:
|
||||
sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1]
|
||||
else:
|
||||
sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1]
|
||||
start_idx = 0
|
||||
|
||||
samplers = [(iter(spl), name, sampler_base_index[idx])
|
||||
for idx, (name, spl) in enumerate(self.sampler.items())]
|
||||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标)
|
||||
for idx, (name, spl) in enumerate(self.sampler.items()):
|
||||
end_idx = len(spl)
|
||||
samplers.append((iter(spl), name, start_idx))
|
||||
start_idx += end_idx
|
||||
# if self.word_size > 0 and self.rank >= 0:
|
||||
# sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1]
|
||||
# else:
|
||||
# sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1]
|
||||
#
|
||||
# samplers = [(iter(spl), name, sampler_base_index[idx])
|
||||
# for idx, (name, spl) in enumerate(self.sampler.items())]
|
||||
for idx in total_index:
|
||||
ds_index = np.searchsorted(self.num_samplers, idx, side='right')
|
||||
|
||||
@ -455,7 +357,7 @@ class MixSequentialSampler(MixSampler):
|
||||
# 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration
|
||||
spl = iter(self.sampler[name])
|
||||
batch_idx.append(next(spl) + base_index)
|
||||
samplers[name] = (spl, name, base_index)
|
||||
samplers[ds_index] = (spl, name, base_index)
|
||||
if len(batch_idx) == self.batch_size:
|
||||
yield batch_idx
|
||||
batch_idx = []
|
||||
@ -506,63 +408,26 @@ class PollingSampler(MixSampler):
|
||||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化
|
||||
if isinstance(self.sampler, str):
|
||||
if self.sampler == 'seq':
|
||||
if isinstance(self.datasets, List):
|
||||
self.sampler = []
|
||||
for per_ds in self.datasets:
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]))
|
||||
else:
|
||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))))
|
||||
elif isinstance(self.datasets, Dict):
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(list(range(len(per_ds))))
|
||||
elif self.sampler == 'rand':
|
||||
if isinstance(self.datasets, List):
|
||||
self.sampler = []
|
||||
for per_ds in self.datasets:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
indices = torch.randperm(len(per_ds), generator=g).tolist()
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size]))
|
||||
else:
|
||||
self.sampler.append(InnerSampler(indices))
|
||||
elif isinstance(self.datasets, Dict):
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
indices = torch.randperm(len(per_ds), generator=g).tolist()
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(indices)
|
||||
|
||||
self.sampler = {}
|
||||
for name, per_ds in self.datasets.items():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
indices = torch.randperm(len(per_ds), generator=g).tolist()
|
||||
if self.word_size >= 0 and self.rank >= 0:
|
||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size])
|
||||
else:
|
||||
self.sampler[name] = InnerSampler(indices)
|
||||
|
||||
# 根据给定的ds_ratio计算真正需要处理数据集
|
||||
if isinstance(self.sampler, List):
|
||||
if self.ds_ratio is None:
|
||||
sampler_lens = [len(spl) for spl in self.sampler]
|
||||
|
||||
elif self.ds_ratio == 'pad_to_most':
|
||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler)
|
||||
|
||||
elif self.ds_ratio == 'truncate_to_least':
|
||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler)
|
||||
|
||||
elif isinstance(self.ds_ratio, List):
|
||||
if not all(item >= 0 for item in self.ds_ratio):
|
||||
raise ValueError("batch_size should be a positive integer value, "
|
||||
"but got ds_ratio={}".format(self.ds_ratio))
|
||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)]
|
||||
else:
|
||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List")
|
||||
total_lens = sum(sampler_lens)
|
||||
|
||||
elif isinstance(self.sampler, Dict):
|
||||
if isinstance(self.sampler, Dict):
|
||||
if self.ds_ratio is None:
|
||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()]
|
||||
|
||||
@ -592,17 +457,15 @@ class PollingSampler(MixSampler):
|
||||
self.num_samplers = sampler_index
|
||||
self.len_samplers = total_lens
|
||||
|
||||
start_idx, samplers = 0, []
|
||||
if isinstance(self.sampler, List):
|
||||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标)
|
||||
for sampler_idx, (end_idx, spl) in enumerate(zip(self.num_samplers, self.sampler)):
|
||||
samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, sampler_idx))
|
||||
start_idx = end_idx
|
||||
else:
|
||||
for idx, (name, spl) in enumerate(self.sampler.items()):
|
||||
end_idx = self.num_samplers[idx]
|
||||
samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, name))
|
||||
start_idx = end_idx
|
||||
start_idx, samplers, true_start_idx, true_end_idx = 0, [], 0, 0
|
||||
|
||||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标)
|
||||
for idx, (name, spl) in enumerate(self.sampler.items()):
|
||||
end_idx = len(spl)
|
||||
true_end_idx = self.num_samplers[idx]
|
||||
samplers.append((iter(range(true_start_idx, true_end_idx)), iter(spl), start_idx, name))
|
||||
start_idx += end_idx
|
||||
true_start_idx = true_end_idx
|
||||
|
||||
while True:
|
||||
# 退出循环
|
||||
|
495
tests/core/dataloaders/test_mixdataloader.py
Normal file
495
tests/core/dataloaders/test_mixdataloader.py
Normal file
@ -0,0 +1,495 @@
|
||||
import pytest
|
||||
from typing import Mapping
|
||||
|
||||
from fastNLP.core.dataloaders import MixDataLoader
|
||||
from fastNLP import DataSet
|
||||
from fastNLP.core.collators import Collator
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
import torch
|
||||
from torch.utils.data import default_collate, SequentialSampler, RandomSampler
|
||||
|
||||
d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
|
||||
|
||||
d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10] * 10})
|
||||
|
||||
d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100})
|
||||
|
||||
|
||||
def test_pad_val(tensor, val=0):
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
tensor = tensor.tolist()
|
||||
for item in tensor:
|
||||
if item[-1] > 0:
|
||||
continue
|
||||
elif item[-1] != val:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class TestMixDataLoader:
|
||||
|
||||
def test_sequential_init(self):
|
||||
datasets = {'d1': d1, 'd2': d2, 'd3': d3}
|
||||
# drop_last = True, collate_fn = 'auto
|
||||
dl = MixDataLoader(datasets=datasets, mode='sequential', collate_fn='auto', drop_last=True)
|
||||
for idx, batch in enumerate(dl):
|
||||
if idx == 0:
|
||||
# d1
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
if idx == 1:
|
||||
# d2
|
||||
assert batch['x'].shape == torch.Size([16, 3])
|
||||
if idx > 1:
|
||||
# d3
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
|
||||
# collate_fn = Callable
|
||||
def collate_batch(batch):
|
||||
new_batch = {'x': [], 'y': []}
|
||||
for ins in batch:
|
||||
new_batch['x'].append(ins['x'])
|
||||
new_batch['y'].append(ins['y'])
|
||||
return new_batch
|
||||
|
||||
dl1 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_batch, drop_last=True)
|
||||
for idx, batch in enumerate(dl1):
|
||||
if idx == 0:
|
||||
# d1
|
||||
assert [1, 2] in batch['x']
|
||||
if idx == 1:
|
||||
# d2
|
||||
assert [101, 201] in batch['x']
|
||||
if idx > 1:
|
||||
# d3
|
||||
assert [1000, 2000] in batch['x']
|
||||
assert 'x' in batch and 'y' in batch
|
||||
|
||||
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1),
|
||||
'd2': Collator(backend='auto').set_pad("x", -2),
|
||||
'd3': Collator(backend='auto').set_pad("x", -3)}
|
||||
dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True)
|
||||
for idx, batch in enumerate(dl2):
|
||||
if idx == 0:
|
||||
assert test_pad_val(batch['x'], val=-1)
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
if idx == 1:
|
||||
assert test_pad_val(batch['x'], val=-2)
|
||||
assert batch['x'].shape == torch.Size([16, 3])
|
||||
if idx > 1:
|
||||
assert test_pad_val(batch['x'], val=-3)
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
|
||||
# sampler 为 str
|
||||
dl3 = MixDataLoader(datasets=datasets, mode='sequential', sampler='seq', drop_last=True)
|
||||
dl4 = MixDataLoader(datasets=datasets, mode='sequential', sampler='rand', drop_last=True)
|
||||
for idx, batch in enumerate(dl3):
|
||||
if idx == 0:
|
||||
# d1
|
||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
if idx == 1:
|
||||
# d2
|
||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
|
||||
assert batch['x'].shape == torch.Size([16, 3])
|
||||
if idx == 2:
|
||||
# d3
|
||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
|
||||
if idx > 1:
|
||||
# d3
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
|
||||
for idx, batch in enumerate(dl4):
|
||||
if idx == 0:
|
||||
# d1
|
||||
assert batch['x'][:3].tolist() != [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
if idx == 1:
|
||||
# d2
|
||||
assert batch['x'][:3].tolist() != [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
|
||||
assert batch['x'].shape == torch.Size([16, 3])
|
||||
if idx == 2:
|
||||
# d3
|
||||
assert batch['x'][:3].tolist() != [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
|
||||
if idx > 1:
|
||||
# d3
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
|
||||
# sampler 为 Dict
|
||||
samplers = {'d1': SequentialSampler(d1),
|
||||
'd2': SequentialSampler(d2),
|
||||
'd3': RandomSampler(d3)}
|
||||
dl5 = MixDataLoader(datasets=datasets, mode='sequential', sampler=samplers, drop_last=True)
|
||||
for idx, batch in enumerate(dl5):
|
||||
if idx == 0:
|
||||
# d1
|
||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
if idx == 1:
|
||||
# d2
|
||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
|
||||
assert batch['x'].shape == torch.Size([16, 3])
|
||||
if idx > 1:
|
||||
# d3
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
|
||||
# ds_ratio 为 'truncate_to_least'
|
||||
dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True)
|
||||
for idx, batch in enumerate(dl6):
|
||||
if idx == 0:
|
||||
# d1
|
||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
if idx == 1:
|
||||
# d2
|
||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
|
||||
assert batch['x'].shape == torch.Size([16, 3])
|
||||
if idx == 2:
|
||||
# d3
|
||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
if idx > 2:
|
||||
raise ValueError(f"ds_ratio: 'truncate_to_least' error")
|
||||
|
||||
# ds_ratio 为 'pad_to_most'
|
||||
dl7 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='pad_to_most', drop_last=True)
|
||||
for idx, batch in enumerate(dl7):
|
||||
if idx < 18:
|
||||
# d1
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
if 18 <= idx < 36:
|
||||
# d2
|
||||
assert batch['x'].shape == torch.Size([16, 3])
|
||||
if 36 <= idx < 54:
|
||||
# d3
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
if idx >= 54:
|
||||
raise ValueError(f"ds_ratio: 'pad_to_most' error")
|
||||
|
||||
# ds_ratio 为 Dict[str, float]
|
||||
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
|
||||
dl8 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True)
|
||||
for idx, batch in enumerate(dl8):
|
||||
if idx < 1:
|
||||
# d1
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
if 1 <= idx < 4:
|
||||
# d2
|
||||
assert batch['x'].shape == torch.Size([16, 3])
|
||||
if 4 <= idx < 41:
|
||||
# d3
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
if idx >= 41:
|
||||
raise ValueError(f"ds_ratio: 'pad_to_most' error")
|
||||
|
||||
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
|
||||
dl9 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True)
|
||||
for idx, batch in enumerate(dl9):
|
||||
if idx < 1:
|
||||
# d2
|
||||
assert batch['x'].shape == torch.Size([16, 3])
|
||||
if 1 <= idx < 19:
|
||||
# d3
|
||||
assert batch['x'].shape == torch.Size([16, 4])
|
||||
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
if idx >= 19:
|
||||
raise ValueError(f"ds_ratio: 'pad_to_most' error")
|
||||
|
||||
def test_mix(self):
|
||||
datasets = {'d1': d1, 'd2': d2, 'd3': d3}
|
||||
dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True)
|
||||
for idx, batch in enumerate(dl):
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
if idx >= 22:
|
||||
raise ValueError(f"out of range")
|
||||
|
||||
# collate_fn = Callable
|
||||
def collate_batch(batch):
|
||||
new_batch = {'x': [], 'y': []}
|
||||
for ins in batch:
|
||||
new_batch['x'].append(ins['x'])
|
||||
new_batch['y'].append(ins['y'])
|
||||
return new_batch
|
||||
|
||||
dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True)
|
||||
for idx, batch in enumerate(dl1):
|
||||
assert isinstance(batch['x'], list)
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
if idx >= 22:
|
||||
raise ValueError(f"out of range")
|
||||
|
||||
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1),
|
||||
'd2': Collator(backend='auto').set_pad("x", -2),
|
||||
'd3': Collator(backend='auto').set_pad("x", -3)}
|
||||
with pytest.raises(ValueError):
|
||||
MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_fns)
|
||||
|
||||
# sampler 为 str
|
||||
dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True)
|
||||
for idx, batch in enumerate(dl3):
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
if idx >= 22:
|
||||
raise ValueError(f"out of range")
|
||||
dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True)
|
||||
for idx, batch in enumerate(dl4):
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
if idx >= 22:
|
||||
raise ValueError(f"out of range")
|
||||
# sampler 为 Dict
|
||||
samplers = {'d1': SequentialSampler(d1),
|
||||
'd2': SequentialSampler(d2),
|
||||
'd3': RandomSampler(d3)}
|
||||
dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True)
|
||||
for idx, batch in enumerate(dl5):
|
||||
assert test_pad_val(batch['x'], val=0)
|
||||
if idx >= 22:
|
||||
raise ValueError(f"out of range")
|
||||
# ds_ratio 为 'truncate_to_least'
|
||||
dl6 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='truncate_to_least')
|
||||
d1_len, d2_len, d3_len = 0, 0, 0
|
||||
for idx, batch in enumerate(dl6):
|
||||
for item in batch['y'].tolist():
|
||||
if item in [1, 0, 1]:
|
||||
d1_len += 1
|
||||
elif item in [20, 10, 10]:
|
||||
d2_len += 1
|
||||
elif item in [100, 100, 200]:
|
||||
d3_len += 1
|
||||
if idx >= 6:
|
||||
raise ValueError(f"ds_ratio 为 'truncate_to_least'出错了")
|
||||
assert d1_len == d2_len == d3_len == 30
|
||||
|
||||
# ds_ratio 为 'pad_to_most'
|
||||
dl7 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='pad_to_most')
|
||||
d1_len, d2_len, d3_len = 0, 0, 0
|
||||
for idx, batch in enumerate(dl7):
|
||||
for item in batch['y'].tolist():
|
||||
if item in [1, 0, 1]:
|
||||
d1_len += 1
|
||||
elif item in [20, 10, 10]:
|
||||
d2_len += 1
|
||||
elif item in [100, 100, 200]:
|
||||
d3_len += 1
|
||||
|
||||
if idx >= 57:
|
||||
raise ValueError(f"ds_ratio 为 'pad_to_most'出错了")
|
||||
assert d1_len == d2_len == d3_len == 300
|
||||
|
||||
# ds_ratio 为 Dict[str, float]
|
||||
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
|
||||
dl8 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio)
|
||||
d1_len, d2_len, d3_len = 0, 0, 0
|
||||
for idx, batch in enumerate(dl8):
|
||||
for item in batch['y'].tolist():
|
||||
if item in [1, 0, 1]:
|
||||
d1_len += 1
|
||||
elif item in [20, 10, 10]:
|
||||
d2_len += 1
|
||||
elif item in [100, 100, 200]:
|
||||
d3_len += 1
|
||||
if idx >= 44:
|
||||
raise ValueError(f"ds_ratio 为 'Dict'出错了")
|
||||
assert d1_len == 30
|
||||
assert d2_len == 60
|
||||
assert d3_len == 600
|
||||
|
||||
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
|
||||
dl9 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio)
|
||||
d1_len, d2_len, d3_len = 0, 0, 0
|
||||
for idx, batch in enumerate(dl9):
|
||||
for item in batch['y'].tolist():
|
||||
if item in [1, 0, 1]:
|
||||
d1_len += 1
|
||||
elif item in [20, 10, 10]:
|
||||
d2_len += 1
|
||||
elif item in [100, 100, 200]:
|
||||
d3_len += 1
|
||||
if idx >= 21:
|
||||
raise ValueError(f"ds_ratio 为 'Dict'出错了")
|
||||
|
||||
def test_polling(self):
|
||||
datasets = {'d1': d1, 'd2': d2, 'd3': d3}
|
||||
dl = MixDataLoader(datasets=datasets, mode='polling', collate_fn='auto', batch_size=18)
|
||||
for idx, batch in enumerate(dl):
|
||||
if idx == 0 or idx == 3:
|
||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
elif idx == 1 or idx == 4:
|
||||
# d2
|
||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
|
||||
assert batch['x'].shape[1] == 3
|
||||
elif idx == 2 or 4 < idx <= 20:
|
||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
if idx > 20:
|
||||
raise ValueError(f"out of range")
|
||||
test_pad_val(batch['x'], val=0)
|
||||
|
||||
# collate_fn = Callable
|
||||
def collate_batch(batch):
|
||||
new_batch = {'x': [], 'y': []}
|
||||
for ins in batch:
|
||||
new_batch['x'].append(ins['x'])
|
||||
new_batch['y'].append(ins['y'])
|
||||
return new_batch
|
||||
|
||||
dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_batch, batch_size=18)
|
||||
for idx, batch in enumerate(dl1):
|
||||
if idx == 0 or idx == 3:
|
||||
assert batch['x'][:3] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]]
|
||||
elif idx == 1 or idx == 4:
|
||||
# d2
|
||||
assert batch['x'][:3] == [[101, 201], [201, 301, 401], [100]]
|
||||
elif idx == 2 or 4 < idx <= 20:
|
||||
assert batch['x'][:3] == [[1000, 2000], [0], [2000, 3000, 4000, 5000]]
|
||||
if idx > 20:
|
||||
raise ValueError(f"out of range")
|
||||
|
||||
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1),
|
||||
'd2': Collator(backend='auto').set_pad("x", -2),
|
||||
'd3': Collator(backend='auto').set_pad("x", -3)}
|
||||
dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18)
|
||||
for idx, batch in enumerate(dl1):
|
||||
if idx == 0 or idx == 3:
|
||||
assert test_pad_val(batch['x'], val=-1)
|
||||
assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
elif idx == 1 or idx == 4:
|
||||
# d2
|
||||
assert test_pad_val(batch['x'], val=-2)
|
||||
assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]]
|
||||
assert batch['x'].shape[1] == 3
|
||||
elif idx == 2 or 4 < idx <= 20:
|
||||
assert test_pad_val(batch['x'], val=-3)
|
||||
assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
if idx > 20:
|
||||
raise ValueError(f"out of range")
|
||||
|
||||
# sampler 为 str
|
||||
dl2 = MixDataLoader(datasets=datasets, mode='polling', sampler='seq', batch_size=18)
|
||||
dl3 = MixDataLoader(datasets=datasets, mode='polling', sampler='rand', batch_size=18)
|
||||
for idx, batch in enumerate(dl2):
|
||||
if idx == 0 or idx == 3:
|
||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
elif idx == 1 or idx == 4:
|
||||
# d2
|
||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
|
||||
assert batch['x'].shape[1] == 3
|
||||
elif idx == 2 or 4 < idx <= 20:
|
||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
if idx > 20:
|
||||
raise ValueError(f"out of range")
|
||||
test_pad_val(batch['x'], val=0)
|
||||
for idx, batch in enumerate(dl3):
|
||||
if idx == 0 or idx == 3:
|
||||
assert batch['x'].shape[1] == 4
|
||||
elif idx == 1 or idx == 4:
|
||||
# d2
|
||||
assert batch['x'].shape[1] == 3
|
||||
elif idx == 2 or 4 < idx <= 20:
|
||||
assert batch['x'].shape[1] == 4
|
||||
if idx > 20:
|
||||
raise ValueError(f"out of range")
|
||||
test_pad_val(batch['x'], val=0)
|
||||
# sampler 为 Dict
|
||||
samplers = {'d1': SequentialSampler(d1),
|
||||
'd2': SequentialSampler(d2),
|
||||
'd3': RandomSampler(d3)}
|
||||
dl4 = MixDataLoader(datasets=datasets, mode='polling', sampler=samplers, batch_size=18)
|
||||
for idx, batch in enumerate(dl4):
|
||||
if idx == 0 or idx == 3:
|
||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
elif idx == 1 or idx == 4:
|
||||
# d2
|
||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
|
||||
assert batch['x'].shape[1] == 3
|
||||
elif idx == 2 or 4 < idx <= 20:
|
||||
assert batch['x'].shape[1] == 4
|
||||
if idx > 20:
|
||||
raise ValueError(f"out of range")
|
||||
test_pad_val(batch['x'], val=0)
|
||||
|
||||
# ds_ratio 为 'truncate_to_least'
|
||||
dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18)
|
||||
for idx, batch in enumerate(dl5):
|
||||
if idx == 0 or idx == 3:
|
||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
elif idx == 1 or idx == 4:
|
||||
# d2
|
||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
|
||||
assert batch['x'].shape[1] == 3
|
||||
elif idx == 2 or idx == 5:
|
||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
if idx > 5:
|
||||
raise ValueError(f"out of range")
|
||||
test_pad_val(batch['x'], val=0)
|
||||
|
||||
# ds_ratio 为 'pad_to_most'
|
||||
dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18)
|
||||
for idx, batch in enumerate(dl6):
|
||||
if idx % 3 == 0:
|
||||
# d1
|
||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
if idx % 3 == 1:
|
||||
# d2
|
||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
|
||||
assert batch['x'].shape[1] == 3
|
||||
if idx % 3 == 2:
|
||||
# d3
|
||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
if idx >= 51:
|
||||
raise ValueError(f"out of range")
|
||||
test_pad_val(batch['x'], val=0)
|
||||
|
||||
# ds_ratio 为 Dict[str, float]
|
||||
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
|
||||
dl7 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
|
||||
for idx, batch in enumerate(dl7):
|
||||
if idx == 0 or idx == 3:
|
||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
elif idx == 1 or idx == 4 or idx == 6 or idx == 8:
|
||||
# d2
|
||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
|
||||
assert batch['x'].shape[1] == 3
|
||||
elif idx == 2 or idx == 5 or idx == 7 or idx > 8:
|
||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
if idx > 39:
|
||||
raise ValueError(f"out of range")
|
||||
test_pad_val(batch['x'], val=0)
|
||||
|
||||
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
|
||||
dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
|
||||
for idx, batch in enumerate(dl8):
|
||||
if idx == 0:
|
||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
elif idx == 1:
|
||||
# d2
|
||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
|
||||
assert batch['x'].shape[1] == 3
|
||||
elif idx > 1:
|
||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
|
||||
assert batch['x'].shape[1] == 4
|
||||
|
||||
if idx > 18:
|
||||
raise ValueError(f"out of range")
|
||||
test_pad_val(batch['x'], val=0)
|
Loading…
Reference in New Issue
Block a user