mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-29 18:59:01 +08:00
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0
This commit is contained in:
commit
0838529873
@ -24,6 +24,7 @@ from fastNLP.core.dataset import DataSet as FDataSet
|
||||
class _JittorDataset(Dataset):
|
||||
"""
|
||||
对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dataset) -> None:
|
||||
@ -39,9 +40,15 @@ class _JittorDataset(Dataset):
|
||||
|
||||
class JittorDataLoader:
|
||||
"""
|
||||
提供给使用jittor框架的DataLoader函数,其能够自动检测数据的类型并判断是否能够pad,若能会自动pad数据,默认pad_val=0;
|
||||
用户可以调用set_pad方法来更改pad_val的值,也可以自定义针对某个field的callate_fn传入到set_field;若用户不想自动pad某个field,
|
||||
则可以调用set_ignore来忽略对某个field的检测和pad。值得注意的是JittorDataLoader输入dataset只要是实现了__getitem__和__len__方法即可。
|
||||
提供给 ``jittor`` 框架使用的 ``DataLoader`` 函数,``JittorDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad,
|
||||
若是可 pad 的 field 则自动 pad 到相同长度,否则只会将相同 field 的数据收集组成一个 batch 返回。
|
||||
具体详见 :class:`~fastNLP.core.collators.Collator`;用户通过 callte_fn 来控制是否使用该功能, collate_fn 只能为 ``['auto', None, Callable]``三种取值。
|
||||
|
||||
* callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。
|
||||
此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
|
||||
* callate_fn 为 ``None`` 时, ``JittorDataLoader`` 默认使用 Jittor DataLoader 自带的 collate_fn
|
||||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
|
||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。
|
||||
|
||||
"""
|
||||
|
||||
@ -51,25 +58,27 @@ class JittorDataLoader:
|
||||
collate_fn: Union[None, str, Callable] = "auto") -> None:
|
||||
"""
|
||||
|
||||
:param dataset: 实现``__getitem__``和``__len__``的dataset
|
||||
:param batch_size: 批次大小
|
||||
:param shuffle: 是否打乱数据集
|
||||
:param drop_last: 是否去掉最后一个不符合``batch_size``的数据
|
||||
:param num_workers: 进程的数量,当``num_workers=0``时不开启多进程
|
||||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。
|
||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
|
||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。
|
||||
:param drop_last: 当 ``drop_last=True`` 时,``JittorDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
|
||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
|
||||
:param num_workers: 当 ``num_workers > 0`` 时, ``JittorDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快
|
||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
|
||||
:param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合num_workers使用,用户可以自定义每个进程的内存大小。
|
||||
:param stop_grad:
|
||||
:param keep_numpy_array: 返回的数据是``np.array`类`型而不是``jittor.array``类型,默认为``False``
|
||||
:param endless: 是否让``JittorDataLoader``无限返回数据,也就是将dataset循环使用使得返回数据是没有限制的。默认为``False``.
|
||||
:param collate_fn: 用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``.
|
||||
|
||||
* ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``;
|
||||
第二点注意的是此时``JittorDataLoader``会调用默认的`callate_batch`函数对sampler到的数据进行简单打包,组成一个batch返回。`
|
||||
* ``callate_fn="auto"``时,``JittorDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``,
|
||||
并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field,
|
||||
可以调用``set_ignore``方法忽略某个field。
|
||||
* ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``JittorDataLoader``会调用传进来的callable函数对
|
||||
数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。
|
||||
:param stop_grad: 是否不使用梯度, 默认 ``True`` 。
|
||||
:param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False``
|
||||
:param endless: 是否让 ``JittorDataLoader`` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``.
|
||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
|
||||
|
||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时,
|
||||
``JittorDataLoader`` 调用默认的 Jittor 框架的 ``DataLoader`` 自带的 ``collate_batch`` 作为 callate_fn 的默认值, 其无法处理
|
||||
:class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。
|
||||
* callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。
|
||||
此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
|
||||
* `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
|
||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。
|
||||
|
||||
"""
|
||||
# TODO 验证支持replacesampler (以后完成) 增加Sampler
|
||||
# 将内部dataset批次设置为1
|
||||
@ -130,8 +139,8 @@ class JittorDataLoader:
|
||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
|
||||
无意义。
|
||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
|
||||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
|
||||
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。
|
||||
:param backend: 可选['raw', 'numpy', 'Jittor', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
|
||||
Jittor.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。
|
||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
|
||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
|
||||
形式,输出将被直接作为结果输出。
|
||||
@ -192,45 +201,53 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Tr
|
||||
non_train_batch_size: int = 16) \
|
||||
-> Union[Sequence[JittorDataLoader], Dict[str, JittorDataLoader], JittorDataLoader]:
|
||||
"""
|
||||
prepare_jittor_dataloader的功能是将多个dataset同时转为dataloader返回。ds_or_db的类型只能为``[Dataset, DataBundle,
|
||||
Sequence[Dataset], Dict[name, Dataset]]``,具体如下:
|
||||
``prepare_jittor_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``JittorDataloader``对象, 详见 :class: `~fastNLP.core.dataloaders.JittorDataLoader`。
|
||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle,Sequence[Dataset], Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下:
|
||||
|
||||
* 当ds_or_db为Dataset时,prepare_jittor_dataloader会将所有的参数除了non_train_batch_size以外来帮你实例化一个
|
||||
JittorDataLoader并返回。
|
||||
* 当ds_or_db为FastNLP的DataBundle时,prepare_jittor_dataloader会遍历所有的dataset并根据其name实例化不同的JittorDataLoader,
|
||||
当name中包含'train'字符串时,prepare_jittor_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串
|
||||
的dataset均使用non_train_batch_size作为batch_size来实例化JittorDataLoader。最终根据name:JittorDataLoader组成一个Dict[name, JittorDataLoader]
|
||||
的数据返回。
|
||||
* 当ds_or_db为Dict[name, Dataset]数据类型时,prepare_jittor_dataloader会遍历所有的dataset并根据其name实例化不同的JittorDataLoader,
|
||||
当name中包含'train'字符串时,prepare_jittor_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串
|
||||
的dataset均使用non_train_batch_size作为batch_size来实例化JittorDataLoader。最终根据name:JittorDataLoader组成一个Dict[name, JittorDataLoader]
|
||||
的数据返回。
|
||||
* 当ds_or_db为Sequence[Dataset]数据类型时, prepare_jittor_dataloader会将Sequence[0]作为默认的train数据集对待,并使用train_batch_size作为
|
||||
其batch_size使用;而Sequence[1:]均视为非train数据集对待,使用non_train_batch_size作为batch_size来实例化JittorDataLoader。最终
|
||||
将所有JittorDataLoader组成Sequence[JittorDataLoader]返回。
|
||||
* 当 ds_or_db 为 ``DataSet``时,``prepare_jittor_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来
|
||||
帮你实例化一个 ``JittorDataLoader`` 对象并返回该对象。 详见:class: `~fastNLP.core.dataloaders.JittorDataLoader`。
|
||||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_Jittor_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value
|
||||
来创建不同的 ``JittorDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_jittor_dataloader`` 默认该 value 为 train 数据集,
|
||||
会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。
|
||||
最终根据 ``key: JittorDataLoader`` 组成 ``Dict[key, JittorDataLoader]`` 的字典返回。
|
||||
* 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_jittor_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的
|
||||
``JittorDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_Jittor_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数,
|
||||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: JittorDataLoader`` 组成
|
||||
``Dict[key, JittorDataLoader]`` 的字典返回。
|
||||
* 当 ds_or_db 为 ``Sequence[Dataset]`` 数据类型时, prepare_jittor_dataloader 会将 Sequence[0] 的数据集默认为 train 数据集对待,
|
||||
会将 batch_size 和 sampler 作为参数, 而 Sequence[1:] 数据集均视为非 train 数据集对待,使用 non_train_size 和 non_train_sampler 作为参数。
|
||||
最终将所有实例化好的 ``JittorDataLoader`` 组成 ``Sequence[JittorDataLoader]`` 返回。
|
||||
|
||||
:param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle,
|
||||
Sequence[Dataset], Dict[name, Dataset]]``.
|
||||
:param batch_size: batch 的大小。
|
||||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 ``Dict`` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数
|
||||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。
|
||||
:param shuffle: 是否打乱数据集
|
||||
:param drop_last: 是否去掉最后一个不符合``batch_size``的数据
|
||||
:param num_workers: 进程的数量,当``num_workers=0``时不开启多进程
|
||||
:param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。其取值只能为 ``[DataSet, DataBundle,
|
||||
Sequence[DataSet], Dict[str, DataSet]]``.
|
||||
|
||||
* ds_or_db 为 :class: `~fastNLP.core.dataset.DataSet`,返回值为:class: `~fastNLP.core.dataloaders.JittorDataLoader`
|
||||
* ds_or_db 为 :class: `~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, JittorDataLoader]`` 的字典
|
||||
* ds_or_db 为 ``Sequence[DataSet]`` 序列, 返回值为 ``Sequence[JittorDataLoader]`` 的序列
|
||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值也为 ``Dict[str, JittorDataLoader]`` 的字典
|
||||
|
||||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 ``Dict``, ``Sequence`` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数
|
||||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 默认为 ``16``。
|
||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
|
||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。
|
||||
:param drop_last: 当 ``drop_last=True`` 时,``JittorDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
|
||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
|
||||
:param num_workers: 当 ``num_workers > 0`` 时, ``JittorDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快
|
||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
|
||||
:param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合num_workers使用,用户可以自定义每个进程的内存大小。
|
||||
:param stop_grad:
|
||||
:param keep_numpy_array: 返回的数据是``np.array`类`型而不是``jittor.array``类型,默认为``False``
|
||||
:param endless: 是否让``JittorDataLoader``无限返回数据,也就是将dataset循环使用使得返回数据是没有限制的。默认为``False``.
|
||||
:param collate_fn: 用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``.
|
||||
|
||||
* ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``;
|
||||
第二点注意的是此时``JittorDataLoader``会调用默认的`callate_batch`函数对sampler到的数据进行简单打包,组成一个batch返回。`
|
||||
* ``callate_fn="auto"``时,``JittorDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``,
|
||||
并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field,
|
||||
可以调用``set_ignore``方法忽略某个field。
|
||||
* ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``JittorDataLoader``会调用传进来的callable函数对
|
||||
数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。
|
||||
:param stop_grad: 是否不使用梯度, 默认 ``True`` 。
|
||||
:param keep_numpy_array: 返回的数据是 ``np.array`` 类型而不是 ``jittor.Var`` 类型,默认为 ``False``
|
||||
:param endless: 是否让 ``JittorDataLoader`` 无限返回数据,也就是将 dataset 循环使用使得返回数据是没有限制的。默认为 ``False``.
|
||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
|
||||
|
||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时,
|
||||
``JittorDataLoader`` 调用默认的 Jittor 框架的 ``DataLoader`` 自带的 ``collate_batch`` 作为 callate_fn 的默认值, 其无法处理
|
||||
:class:`~fastNLP.core.dataset.DataSet` 的 dataset 对象。
|
||||
* callate_fn 为 ``'auto'`` 时,``JittorDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。
|
||||
此时可以配套使用 ``JittorDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
|
||||
* `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
|
||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。
|
||||
|
||||
:return: 返回数据类型为Sequence[JittorDataLoader], Dict[str, JittorDataLoader], JittorDataLoader其中之一,根据输入ds_or_db变化而变化。
|
||||
"""
|
||||
from fastNLP.io.data_bundle import DataBundle
|
||||
|
@ -1,31 +1,3 @@
|
||||
"""
|
||||
``PaddleDataLoader``是专门提供给``paddle``框架的``DataLoader``,其集成了``fastNLP``的``Collator``并对``paddle``的``DataLoader``进行了
|
||||
封装,使得其具备以下功能:1.``PaddleDataLoader``支持输入的dataset是无框架的,只要实现了``__getitem__``和``__len__``方法即可,当不使用``fastNLP``的
|
||||
``DataSet``时候也能够自动检测数据的类型并进行padding,只需要将``collate_fn="auto"``即可,例如::
|
||||
|
||||
from fastNLP import PaddleDataLoader
|
||||
class MyDataset:
|
||||
def __init(self, data_lens=100):
|
||||
self.data_lens = 100
|
||||
def __getitem__(self, item):
|
||||
if item % 2 == 0:
|
||||
return {'x':[101, 256, 453], 'y': 0}
|
||||
else:
|
||||
return {'x': [101, 200], 'y': 1}
|
||||
def __len__(self):
|
||||
return self.data_lens
|
||||
dataset = MyDataset()
|
||||
paddle_dl = PaddleDataLoader(dataset, collate_fn="auto")
|
||||
for batch in paddle_dl:
|
||||
...
|
||||
|
||||
2.当设置``collate_fn="auto"``时,``PaddleDataLoader``会调用fastNLP的Collator对数据进行自动pad处理,此时可以调用``set_pad``和``set_ignore``方法
|
||||
来设置field的pad_val或者忽略某个field的pad操作。
|
||||
.. note::
|
||||
当传入的dataset为fastNLP的DataSet时,collate_fn不能为None。默认可以是"auto"或者自定义callable函数。
|
||||
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
'PaddleDataLoader',
|
||||
'prepare_paddle_dataloader'
|
||||
@ -74,60 +46,84 @@ class _PaddleDataset(Dataset):
|
||||
|
||||
class PaddleDataLoader(DataLoader):
|
||||
"""
|
||||
提供给``paddle``框架使用的``DataLoader``函数,``PaddleDataLoader``提供了``Collator``的功能,用户可以通过设置``collate_fn="auto"``来
|
||||
使用,并可以配套使用``set_pad``和``set_ignore``方法设置p``ad_val``和忽略某个field的pad操作。
|
||||
``PaddleDataLoader`` 是专门提供给 ``paddle`` 框架的 ``DataLoader`` ,其集成了 ``fastNLP`` 的 ``Collator`` ,
|
||||
具体详见 :class:`~fastNLP.core.collators.Collator`, 并对 ``paddle`` 的 ``DataLoader`` 进行了
|
||||
封装,使得其具备以下功能:1. ``PaddleDataLoader`` 支持输入的 dataset 是无框架的,只要实现了 __getitem__() 和 __len__() 的对象即可,
|
||||
当不使用 :class: `~fastNLP.core.dataset.DataSet` 时也不需要传入 collate_fn, 只要只需要将 ``collate_fn='auto'`` 就能够自动
|
||||
探测数据的类型并判断能否 pad .此时可以调用 ``set_pad`` 和 ``set_ignore`` 方法来设置 field 的 pad_val 或者忽略某个 field 的 pad 操作。
|
||||
Example::
|
||||
|
||||
from fastNLP import PaddleDataLoader
|
||||
class MyDataset:
|
||||
def __init(self, data_lens=100):
|
||||
self.data_lens = 100
|
||||
def __getitem__(self, item):
|
||||
if item % 2 == 0:
|
||||
return {'x':[101, 256, 453], 'y': 0}
|
||||
else:
|
||||
return {'x': [101, 200], 'y': 1}
|
||||
def __len__(self):
|
||||
return self.data_lens
|
||||
dataset = MyDataset()
|
||||
paddle_dl = PaddleDataLoader(dataset, collate_fn='auto')
|
||||
for batch in paddle_dl:
|
||||
...
|
||||
|
||||
2.当 collate_fn 为 ``None`` 时,``PaddleDataLoader`` 默认使用 ``paddle`` 自带的 ``default_collate_fn`` 作为 collate_fn 的值
|
||||
|
||||
.. note::
|
||||
当传入的dataset为fastNLP的DataSet时,collate_fn不能为None。默认可以是"auto"或者自定义callable函数。
|
||||
|
||||
3. 当 collate_fn 为 ``Callable`` 时,该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
|
||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, feed_list=None, places=None,
|
||||
return_list: bool = True, batch_sampler=None,
|
||||
batch_size: int = 1, shuffle: bool = True,
|
||||
batch_size: int = 16, shuffle: bool = True,
|
||||
drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
|
||||
num_workers: int = 0, use_buffer_reader: bool = True,
|
||||
use_shared_memory: bool = True, timeout: int = 0,
|
||||
worker_init_fn: Callable = None, persistent_workers=False) -> None:
|
||||
"""
|
||||
|
||||
:param dataset: 实现了__getitem__和__len__的数据容器
|
||||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。
|
||||
:param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list.
|
||||
The Tensors should be created by :code:`paddle.static.data()`.
|
||||
:attr:`feed_list` must be set if :attr:`return_list` is
|
||||
False. Default None.
|
||||
:param places: (list(Place)|tuple(Place)|list(str)|optional): a list of Place,
|
||||
to put data onto, :attr:`places` can be None, if
|
||||
:attr:`places` is None, default place(CPUPlace or CUDAPlace(0))
|
||||
will be used. Default None. If ``places`` is list of string,
|
||||
the string in the list can be ``cpu``, ``gpu:x`` and ``gpu_pinned``,
|
||||
where ``x`` is the index of the GPUs.
|
||||
:param return_list: whether the return value on each device is
|
||||
presented as a list. If :attr:`return_list=False`, the return
|
||||
value on each device would be a dict of str -> Tensor, where
|
||||
the key of the dict is the name of each fed Tensors. If
|
||||
:attr:`return_list=True`, the return value on each device would
|
||||
be a list(Tensor). :attr:`return_list` can only be True
|
||||
in dynamic graph mode. Default True.
|
||||
:param batch_sampler: 实现了``__iter__``和``__len__``方法的实例化对象,它的功能是根据dataset生成数据indices并组成一个batch数据。
|
||||
:param batch_size: dataloader每次获得数据的批次大小
|
||||
这个张量能被 :code:`paddle.static.data()` 创建。 如果:attr:`return_list` 是 ``False``, 那么 :attr:`feed_list`
|
||||
应该被设置。 默认为 ``None ``
|
||||
:param places: (list(Place)|tuple(Place)|list(str)|optional): 将数据放进的一个 list 的 place。 :attr:`places` 能为 None.
|
||||
如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串
|
||||
可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。
|
||||
:param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`,
|
||||
每个设备上的返回值值为 str -> Tensor 的 dict, 其中 dict 的 key 为每个 fed Tensors 的名字。
|
||||
如果 :attr:`return_list=True`, 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` 只能在动态图情况下设置为 ``True`` .
|
||||
默认值为 ``True`` 。
|
||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为
|
||||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, shuffle 参数均失效。
|
||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
|
||||
:param shuffle: 是否将数据打乱,若``shuffle=True``则会将dataset打乱;若否则什么也不做。
|
||||
:param drop_last: 当``drop_last=True``时,``PaddleDataLoader``会扔掉最后一个不能组成``batch_size``大小的batch数据;
|
||||
若``drop_last=False``, 则什么也不做。
|
||||
:param collate_fn:用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``.
|
||||
:param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
|
||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
|
||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
|
||||
|
||||
* ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``;
|
||||
第二点注意的是此时``PaddleDataLoader``会调用默认的`default_collate_fn`函数对sampler到的数据进行简单打包,组成一个batch返回。`
|
||||
* ``callate_fn="auto"``时,``PaddleDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``,
|
||||
并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field,
|
||||
可以调用``set_ignore``方法忽略某个field。
|
||||
* ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``PaddleDataLoader``会调用传进来的callable函数对
|
||||
数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。
|
||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时,
|
||||
``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 ``default_collate_fn`` 作为 callate_fn 的默认值, 其无法处理
|
||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。
|
||||
* callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。
|
||||
此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
|
||||
* `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
|
||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。
|
||||
|
||||
:param num_workers: 开启多进程的数量,当``num_workers=0``时不开启多进程
|
||||
:param use_buffer_reader: 是否开启buffer_reader。如果``use_buffer_reader=True``,那么``PaddleDataLoader``将会异步的预取下一个batch的
|
||||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是``True``。如果``use_buffer_reader=False``,那么什么也不错
|
||||
:param use_shared_memory: 是否使用共享内存。当``use_shared_memory=True``时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的
|
||||
共享空间足够大时。(例如Linux上的/dev/shm/空间足够大)共享内存仅在多进程模式(num_workers>0)下生效。
|
||||
:param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快
|
||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
|
||||
:param use_buffer_reader: 是否开启 buffer_reader 。如果 `use_buffer_reader=True`` ,那么 ``PaddleDataLoader` `会异步的预取下一个 batch 的
|
||||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。
|
||||
:param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的
|
||||
共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( num_workers>0 )下生效。
|
||||
:param timeout: 从子进程的输出队列获取数据的超时值
|
||||
:param worker_init_fn: init函数,如果不设置为None,则将会在每个子进程初始化时调用该函数。
|
||||
:param persistent_workers:
|
||||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。
|
||||
:param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False``
|
||||
|
||||
"""
|
||||
# FastNLP Datset, collate_fn not None
|
||||
@ -195,8 +191,8 @@ class PaddleDataLoader(DataLoader):
|
||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
|
||||
无意义。
|
||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
|
||||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'paddle', 'auto'],分别代表,输出为 list, numpy.ndarray,
|
||||
torch.Tensor, paddle.Tensor, paddle.Var 类型。若 pad_val 为 None ,该值无意义 。
|
||||
:param backend: 可选['raw', 'numpy', 'Paddle', 'paddle', 'paddle', 'auto'],分别代表,输出为 list, numpy.ndarray,
|
||||
Paddle.Tensor, paddle.Tensor, paddle.Var 类型。若 pad_val 为 None ,该值无意义 。
|
||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
|
||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
|
||||
形式,输出将被直接作为结果输出。
|
||||
@ -261,69 +257,67 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
|
||||
non_train_batch_size: int = 16) \
|
||||
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
|
||||
"""
|
||||
prepare_paddle_dataloader的功能是将多个dataset同时转为dataloader返回。ds_or_db的类型只能为``[Dataset, DataBundle,
|
||||
Sequence[Dataset], Dict[name, Dataset]]``,具体如下:
|
||||
``prepare_paddle_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``PaddleDataloader``对象, 详见 :class: `~fastNLP.core.dataloaders.PaddleDataLoader`。
|
||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle,Sequence[Dataset], Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下:
|
||||
|
||||
* 当 ds_or_db 为 ``DataSet``时,``prepare_paddle_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来
|
||||
帮你实例化一个 ``PaddleDataLoader`` 对象并返回该对象。 详见:class: `~fastNLP.core.dataloaders.PaddleDataLoader`。
|
||||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_paddle_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value
|
||||
来创建不同的 ``PaddleDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_Paddle_dataloader`` 默认该 value 为 train 数据集,
|
||||
会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。
|
||||
最终根据 ``key: PaddleDataLoader`` 组成 ``Dict[key, PaddleDataLoader]`` 的字典返回。
|
||||
* 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_paddle_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的
|
||||
``PaddleDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_paddle_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数,
|
||||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: PaddleDataLoader`` 组成
|
||||
``Dict[key, PaddleDataLoader]`` 的字典返回。
|
||||
* 当 ds_or_db 为 ``Sequence[Dataset]`` 数据类型时, prepare_paddle_dataloader 会将 Sequence[0] 的数据集默认为 train 数据集对待,
|
||||
会将 batch_size 和 sampler 作为参数, 而 Sequence[1:] 数据集均视为非 train 数据集对待,使用 non_train_size 和 non_train_sampler 作为参数。
|
||||
最终将所有实例化好的 ``PaddleDataLoader`` 组成 ``Sequence[PaddleDataLoader]`` 返回。
|
||||
|
||||
::param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。其取值只能为 ``[DataSet, DataBundle,
|
||||
Sequence[DataSet], Dict[str, DataSet]]``.
|
||||
|
||||
* ds_or_db 为 :class: `~fastNLP.core.dataset.DataSet`,返回值为:class: `~fastNLP.core.dataloaders.PaddleDataLoader`
|
||||
* ds_or_db 为 :class: `~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, PaddleDataLoader]`` 的字典
|
||||
* ds_or_db 为 ``Sequence[DataSet]`` 序列, 返回值为 ``Sequence[PaddleDataLoader]`` 的序列
|
||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值也为 ``Dict[str, PaddleDataLoader]`` 的字典
|
||||
|
||||
* 当ds_or_db为Dataset时,prepare_paddle_dataloader会将所有的参数除了non_train_batch_size以外来帮你实例化一个
|
||||
paddleDataLoader并返回。
|
||||
* 当ds_or_db为FastNLP的DataBundle时,prepare_paddle_dataloader会遍历所有的dataset并根据其name实例化不同的paddleDataLoader,
|
||||
当name中包含'train'字符串时,prepare_paddle_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串
|
||||
的dataset均使用non_train_batch_size作为batch_size来实例化paddleDataLoader。最终根据name:paddleDataLoader组成一个Dict[name, paddleDataLoader]
|
||||
的数据返回。
|
||||
* 当ds_or_db为Dict[name, Dataset]数据类型时,prepare_paddle_dataloader会遍历所有的dataset并根据其name实例化不同的paddleDataLoader,
|
||||
当name中包含'train'字符串时,prepare_paddle_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串
|
||||
的dataset均使用non_train_batch_size作为batch_size来实例化paddleDataLoader。最终根据name:paddleDataLoader组成一个Dict[name, paddleDataLoader]
|
||||
的数据返回。
|
||||
* 当ds_or_db为Sequence[Dataset]数据类型时, prepare_paddle_dataloader会将Sequence[0]作为默认的train数据集对待,并使用train_batch_size作为
|
||||
其batch_size使用;而Sequence[1:]均视为非train数据集对待,使用non_train_batch_size作为batch_size来实例化paddleDataLoader。最终
|
||||
将所有paddleDataLoader组成Sequence[paddleDataLoader]返回。
|
||||
|
||||
:param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle,
|
||||
Sequence[Dataset], Dict[name, Dataset]]``.
|
||||
:param batch_size: batch 的大小。
|
||||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 ``Dict`` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数
|
||||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。
|
||||
:param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list.
|
||||
The Tensors should be created by :code:`paddle.static.data()`.
|
||||
:attr:`feed_list` must be set if :attr:`return_list` is
|
||||
False. Default None.
|
||||
:param places: (list(Place)|tuple(Place)|list(str)|optional): a list of Place,
|
||||
to put data onto, :attr:`places` can be None, if
|
||||
:attr:`places` is None, default place(CPUPlace or CUDAPlace(0))
|
||||
will be used. Default None. If ``places`` is list of string,
|
||||
the string in the list can be ``cpu``, ``gpu:x`` and ``gpu_pinned``,
|
||||
where ``x`` is the index of the GPUs.
|
||||
:param return_list: whether the return value on each device is
|
||||
presented as a list. If :attr:`return_list=False`, the return
|
||||
value on each device would be a dict of str -> Tensor, where
|
||||
the key of the dict is the name of each fed Tensors. If
|
||||
:attr:`return_list=True`, the return value on each device would
|
||||
be a list(Tensor). :attr:`return_list` can only be True
|
||||
in dynamic graph mode. Default True.
|
||||
:param batch_sampler: 实现了``__iter__``和``__len__``方法的实例化对象,它的功能是根据dataset生成数据indices并组成一个batch数据。
|
||||
这个张量能被 :code:`paddle.static.data()` 创建。 如果:attr:`return_list` 是 ``False``, 那么 :attr:`feed_list`
|
||||
应该被设置。 默认为 ``None ``
|
||||
:param places: (list(Place)|tuple(Place)|list(str)|optional): 将数据放进的一个 list 的 place。 :attr:`places` 能为 None.
|
||||
如果 :attr:`places` 为 None, 默认放在 CPUPlace 或者 CUDAPlace(0) 设备上。 如果 ``places`` 是一个 list 类型的 字符串, 那么字符串
|
||||
可以是 ``cpu`` , ``gpu:x`` 或者 ``gpu_pinned`` , 其中 ``x`` 是 gpu 的下标。
|
||||
:param return_list: 每个设备上的返回值是否为以列表形式显示。 如果 :attr:`return_list=False`,
|
||||
每个设备上的返回值值为 str -> Tensor 的 dict, 其中 dict 的 key 为每个 fed Tensors 的名字。
|
||||
如果 :attr:`return_list=True`, 每个设备上的返回值值为 list(Tensor)。 :attr:`return_list` 只能在动态图情况下设置为 ``True`` .
|
||||
默认值为 ``True`` 。
|
||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为
|
||||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, shuffle 参数均失效。
|
||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
|
||||
:param shuffle: 是否将数据打乱,若``shuffle=True``则会将dataset打乱;若否则什么也不做。
|
||||
:param drop_last: 当``drop_last=True``时,``PaddleDataLoader``会扔掉最后一个不能组成``batch_size``大小的batch数据;
|
||||
若``drop_last=False``, 则什么也不做。
|
||||
:param collate_fn:用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``.
|
||||
:param drop_last: 当 ``drop_last=True`` 时,``PaddleDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
|
||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
|
||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
|
||||
|
||||
* ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``;
|
||||
第二点注意的是此时``PaddleDataLoader``会调用默认的`default_collate_fn`函数对sampler到的数据进行简单打包,组成一个batch返回。`
|
||||
* ``callate_fn="auto"``时,``PaddleDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``,
|
||||
并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field,
|
||||
可以调用``set_ignore``方法忽略某个field。
|
||||
* ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``PaddleDataLoader``会调用传进来的callable函数对
|
||||
数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。
|
||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时,
|
||||
``PaddleDataLoader`` 调用默认的 Paddle 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理
|
||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。
|
||||
* callate_fn 为 ``'auto'`` 时,``PaddleDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。
|
||||
此时可以配套使用 ``PaddleDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
|
||||
* `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
|
||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。
|
||||
|
||||
:param num_workers: 开启多进程的数量,当``num_workers=0``时不开启多进程
|
||||
:param use_buffer_reader: 是否开启buffer_reader。如果``use_buffer_reader=True``,那么``PaddleDataLoader``将会异步的预取下一个batch的
|
||||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是``True``。如果``use_buffer_reader=False``,那么什么也不错
|
||||
:param use_shared_memory: 是否使用共享内存。当``use_shared_memory=True``时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的
|
||||
共享空间足够大时。(例如Linux上的/dev/shm/空间足够大)共享内存仅在多进程模式(num_workers>0)下生效。
|
||||
:param num_workers: 当 ``num_workers > 0`` 时, ``PaddleDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快
|
||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
|
||||
:param use_buffer_reader: 是否开启 buffer_reader 。如果 `use_buffer_reader=True`` ,那么 ``PaddleDataLoader` `会异步的预取下一个 batch 的
|
||||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是 ``True``。
|
||||
:param use_shared_memory: 是否使用共享内存。当 ``use_shared_memory=True`` 时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的
|
||||
共享空间足够大时。(例如 Linux 上的 /dev/shm/ 空间足够大)共享内存仅在多进程模式( num_workers>0 )下生效。
|
||||
:param timeout: 从子进程的输出队列获取数据的超时值
|
||||
:param worker_init_fn: init函数,如果不设置为None,则将会在每个子进程初始化时调用该函数。
|
||||
:param persistent_workers:
|
||||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。
|
||||
:param persistent_workers: 如果其为 ``True``, ``PaddleDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False``
|
||||
|
||||
:return:
|
||||
"""
|
||||
from fastNLP.io.data_bundle import DataBundle
|
||||
if isinstance(ds_or_db, Dataset):
|
||||
|
@ -21,8 +21,13 @@ else:
|
||||
|
||||
class _FDataSet:
|
||||
"""
|
||||
对Dataset的封装,主要是修改dataset的__getitem__函数,增加返回下标idx,值得注意的是dataset需要实现__getattribute__函数才能在_FDataset
|
||||
中调用dataset的方法
|
||||
提供给 ``TorchDataLoader`` 使用的 warp 类,其功能是对 dataset 进行封装,wrap 修改 dataset 的 __getitem__ 函数,增加返回
|
||||
数据的下标 idx 。
|
||||
|
||||
..note::
|
||||
|
||||
需要注意的是传入 ``__init__`` 的 dataset 需要实现 __getattribute__ 方法才能在 _FDataset 实例化对象中调用 dataset 的方法
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dataset) -> None:
|
||||
@ -43,8 +48,16 @@ class _FDataSet:
|
||||
|
||||
class TorchDataLoader(DataLoader):
|
||||
"""
|
||||
提供给``torch``框架使用的``DataLoader``函数,``TorchDataLoader``提供了``Collator``的功能,用户可以通过设置``collate_fn="auto"``来
|
||||
使用,并可以配套使用``set_pad``和``set_ignore``方法设置p``ad_val``和忽略某个field的pad操作。
|
||||
提供给 ``torch`` 框架使用的 ``DataLoader`` 函数,``TorchDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad,
|
||||
若是可 pad 的 field 则自动 pad 到相同长度,否则只会将相同 field 的数据收集组成一个 batch 返回。
|
||||
具体详见 :class:`~fastNLP.core.collators.Collator`;用户通过 callte_fn 来控制是否使用该功能, collate_fn 只能为 ``['auto', None, Callable]``三种取值。
|
||||
|
||||
* callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。
|
||||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
|
||||
* callate_fn 为 ``None`` 时, ``TorchDataLoadr`` 默认使用 torch DataLoader 自带的 collate_fn
|
||||
* collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
|
||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, batch_size: int = 16,
|
||||
@ -57,32 +70,34 @@ class TorchDataLoader(DataLoader):
|
||||
persistent_workers: bool = False, **kwargs) -> None:
|
||||
"""
|
||||
|
||||
:param dataset: 实现了__getitem__和__len__的数据容器
|
||||
:param batch_size: 批次大小,当batch_sampler为None生效
|
||||
:param shuffle: 是否打乱数据集
|
||||
:param sampler: 实现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效
|
||||
:param batch_sampler: 实现了__len__和__iter__方法的实例化对象,,其能迭代返回一个list的index数据, index不超过dataset的大小,
|
||||
当其不为None时,bacth_size,sampler,shuffle均无效。
|
||||
:param num_workers: 开启子进程的数量,当num_worker=0时不开启多进程
|
||||
:param collate_fn:用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``.
|
||||
:param dataset: 实现了 __getitem__() 和 __len__() 的对象。
|
||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
|
||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。
|
||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,
|
||||
默认为None, 当其不为 None 时, shuffle 参数无效。
|
||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为
|
||||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。
|
||||
:param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快
|
||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
|
||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
|
||||
|
||||
* ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``;
|
||||
第二点注意的是此时``TorchDataLoader``会调用默认的`default_collate_fn`函数对sampler到的数据进行简单打包,组成一个batch返回。`
|
||||
* ``callate_fn="auto"``时,``TorchDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``,
|
||||
并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field,
|
||||
可以调用``set_ignore``方法忽略某个field。
|
||||
* ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``TorchDataLoader``会调用传进来的callable函数对
|
||||
数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。
|
||||
* callate_fn 为 ``None`` 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时,
|
||||
``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 ``default_collate_fn`` 作为 callate_fn 的默认值, 其无法处理
|
||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。
|
||||
* callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。
|
||||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
|
||||
* `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
|
||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。
|
||||
|
||||
:param pin_memory: 如果其为True, 那么DataLoader会在返回数据张量之前将其copy到cuda的pin memory中。
|
||||
:param drop_last: 当``drop_last=True``时,``TorchDataLoader``会扔掉最后一个不能组成``batch_size``大小的batch数据;
|
||||
若``drop_last=False``, 则什么也不做。
|
||||
:param timeout: 从子进程的输出队列获取数据的超时值
|
||||
:param worker_init_fn: init函数,如果不设置为None,则将会在每个子进程初始化时调用该函数。
|
||||
:param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。
|
||||
:param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
|
||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
|
||||
:param timeout: 子进程的输出队列获取数据的超时值
|
||||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。
|
||||
:param multiprocessing_context: 多进程的上下文环境
|
||||
:param generator: 如果其不为None, 将会使用RandomSampler去生成随机的index并且多进程会每个子进程生成一个``base_seed``
|
||||
:param prefetch_factor: 每个worker提前装载的samples数量。``2``意味着在所有的进程中会有2*num_workers的数据被预取。默认值为2.
|
||||
:param persistent_workers: 如果其为True, dataloader会在迭代完一次dataset后不会所有进程。默认为False
|
||||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个``base_seed``
|
||||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` .
|
||||
:param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False``
|
||||
|
||||
"""
|
||||
if isinstance(dataset, DataSet) and collate_fn is None:
|
||||
@ -209,53 +224,62 @@ def prepare_torch_dataloader(ds_or_db,
|
||||
non_train_batch_size: int = 16) \
|
||||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
|
||||
"""
|
||||
prepare_torch_dataloader的功能是将多个dataset同时转为dataloader返回。ds_or_db的类型只能为``[Dataset, DataBundle,
|
||||
Sequence[Dataset], Dict[name, Dataset]]``,具体如下:
|
||||
``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader``对象, 详见 :class: `~fastNLP.core.dataloaders.TorchDataLoader`。
|
||||
根据 ds_or_db 的类型 ``[DataSet, DataBundle,Sequence[Dataset], Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下:
|
||||
|
||||
* 当ds_or_db为Dataset时,prepare_torch_dataloader会将所有的参数除了non_train_batch_size以外来帮你实例化一个
|
||||
torchDataLoader并返回。
|
||||
* 当ds_or_db为FastNLP的DataBundle时,prepare_torch_dataloader会遍历所有的dataset并根据其name实例化不同的torchDataLoader,
|
||||
当name中包含'train'字符串时,prepare_torch_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串
|
||||
的dataset均使用non_train_batch_size作为batch_size来实例化torchDataLoader。最终根据name:torchDataLoader组成一个Dict[name, torchDataLoader]
|
||||
的数据返回。
|
||||
* 当ds_or_db为Dict[name, Dataset]数据类型时,prepare_torch_dataloader会遍历所有的dataset并根据其name实例化不同的torchDataLoader,
|
||||
当name中包含'train'字符串时,prepare_torch_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串
|
||||
的dataset均使用non_train_batch_size作为batch_size来实例化torchDataLoader。最终根据name:torchDataLoader组成一个Dict[name, torchDataLoader]
|
||||
的数据返回。
|
||||
* 当ds_or_db为Sequence[Dataset]数据类型时, prepare_torch_dataloader会将Sequence[0]作为默认的train数据集对待,并使用train_batch_size作为
|
||||
其batch_size使用;而Sequence[1:]均视为非train数据集对待,使用non_train_batch_size作为batch_size来实例化torchDataLoader。最终
|
||||
将所有torchDataLoader组成Sequence[torchDataLoader]返回。
|
||||
* 当 ds_or_db 为 ``DataSet``时,``prepare_torch_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来
|
||||
帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见:class: `~fastNLP.core.dataloaders.TorchDataLoader`。
|
||||
* 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_torch_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value
|
||||
来创建不同的 ``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集,
|
||||
会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。
|
||||
最终根据 ``key: TorchDataLoader`` 组成 ``Dict[key, TorchDataLoader]`` 的字典返回。
|
||||
* 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_torch_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的
|
||||
``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数,
|
||||
其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: TorchDataLoader`` 组成
|
||||
``Dict[key, TorchDataLoader]`` 的字典返回。
|
||||
* 当 ds_or_db 为 ``Sequence[Dataset]`` 数据类型时, prepare_torch_dataloader 会将 Sequence[0] 的数据集默认为 train 数据集对待,
|
||||
会将 batch_size 和 sampler 作为参数, 而 Sequence[1:] 数据集均视为非 train 数据集对待,使用 non_train_size 和 non_train_sampler 作为参数。
|
||||
最终将所有实例化好的 ``TorchDataLoader`` 组成 ``Sequence[TorchDataLoader]`` 返回。
|
||||
|
||||
:param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle,
|
||||
Sequence[Dataset], Dict[name, Dataset]]``.
|
||||
:param shuffle: 是否打乱数据集
|
||||
:param batch_size: batch 的大小。
|
||||
:param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 ``Dict`` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数
|
||||
设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。
|
||||
:param train_sampler: train'数据集使用的sampler, 现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效
|
||||
:param non_train_sampler: 非'train'数据使用sampler, 实现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效
|
||||
:param batch_sampler: 实现了__len__和__iter__方法的实例化对象,,其能迭代返回一个list的index数据, index不超过dataset的大小,
|
||||
当其不为None时,bacth_size,sampler,shuffle均无效。
|
||||
:param num_workers: 开启子进程的数量,当num_worker=0时不开启多进程
|
||||
:param collate_fn:用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``.
|
||||
:param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。其取值只能为 ``[DataSet, DataBundle,
|
||||
Sequence[DataSet], Dict[str, DataSet]]``.
|
||||
|
||||
* ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``;
|
||||
第二点注意的是此时``TorchDataLoader``会调用默认的`default_collate_fn`函数对sampler到的数据进行简单打包,组成一个batch返回。`
|
||||
* ``callate_fn="auto"``时,``TorchDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``,
|
||||
并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field,
|
||||
可以调用``set_ignore``方法忽略某个field。
|
||||
* ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``TorchDataLoader``会调用传进来的callable函数对
|
||||
数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。
|
||||
* ds_or_db 为 :class: `~fastNLP.core.dataset.DataSet`,返回值为:class: `~fastNLP.core.dataloaders.TorchDataLoader`
|
||||
* ds_or_db 为 :class: `~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典
|
||||
* ds_or_db 为 ``Sequence[DataSet]`` 序列, 返回值为 ``Sequence[TorchDataLoader]`` 的序列
|
||||
* ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值也为 ``Dict[str, TorchDataLoader]`` 的字典
|
||||
|
||||
:param pin_memory: 如果其为True, 那么DataLoader会在返回数据张量之前将其copy到cuda的pin memory中。
|
||||
:param drop_last: 当``drop_last=True``时,``TorchDataLoader``会扔掉最后一个不能组成``batch_size``大小的batch数据;
|
||||
若``drop_last=False``, 则什么也不做。
|
||||
:param timeout: 从子进程的输出队列获取数据的超时值
|
||||
:param worker_init_fn: init函数,如果不设置为None,则将会在每个子进程初始化时调用该函数。
|
||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
|
||||
:param non_train_batch_size: 非 'train' 数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。
|
||||
:param shuffle: 是否打乱数据集, 默认为 ``False``。
|
||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,
|
||||
默认为None, 当其不为 None 时, shuffle 参数无效。
|
||||
:param non_train_sampler: 非 'train' 数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,
|
||||
默认为None, 当其不为 None 时, shuffle 参数无效。
|
||||
:param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为
|
||||
dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。
|
||||
:param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快
|
||||
数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。
|
||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``.
|
||||
|
||||
* callate_fn 为 'None' 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时,
|
||||
``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理
|
||||
:class:`~fastNLP.core.dataset.DataSet` 的dataset对象。
|
||||
* callate_fn 为 ``'auto'`` 时,`TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。
|
||||
此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。
|
||||
* `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是
|
||||
dataset 的一条数据;该 Callable 函数还应当返回一个对象。
|
||||
|
||||
:param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。
|
||||
:param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据;
|
||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。
|
||||
:param timeout: 子进程的输出队列获取数据的超时值
|
||||
:param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。
|
||||
:param multiprocessing_context: 多进程的上下文环境
|
||||
:param generator: 如果其不为None, 将会使用RandomSampler去生成随机的index并且多进程会每个子进程生成一个``base_seed``
|
||||
:param prefetch_factor: 每个worker提前装载的samples数量。``2``意味着在所有的进程中会有2*num_workers的数据被预取。默认值为2.
|
||||
:param persistent_workers: 如果其为True, dataloader会在迭代完一次dataset后不会所有进程。默认为False
|
||||
:param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个``base_seed``
|
||||
:param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` .
|
||||
:param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False``
|
||||
|
||||
"""
|
||||
|
||||
from fastNLP.io import DataBundle
|
||||
|
@ -5,6 +5,7 @@ __all__ = [
|
||||
from typing import Union, List
|
||||
from collections import Counter
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
from .metric import Metric
|
||||
from .backend import Backend
|
||||
@ -132,10 +133,10 @@ class ClassifyFPreRecMetric(Metric):
|
||||
seq_len = self.tensor2numpy(seq_len)
|
||||
|
||||
if seq_len is not None and target.ndim > 1:
|
||||
max_len = target.ndim[-1]
|
||||
max_len = target.shape[-1]
|
||||
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
|
||||
else:
|
||||
masks = None
|
||||
masks = np.ones_like(target)
|
||||
|
||||
if pred.ndim == target.ndim:
|
||||
if len(pred.flatten()) != len(target.flatten()):
|
||||
@ -143,7 +144,6 @@ class ClassifyFPreRecMetric(Metric):
|
||||
f" while target have element numbers:{len(pred.flatten())}, "
|
||||
f"pred have element numbers: {len(target.flatten())}")
|
||||
|
||||
pass
|
||||
elif pred.ndim == target.ndim + 1:
|
||||
pred = pred.argmax(axis=-1)
|
||||
if seq_len is None and target.ndim > 1:
|
||||
@ -152,11 +152,9 @@ class ClassifyFPreRecMetric(Metric):
|
||||
raise RuntimeError(f"when pred have "
|
||||
f"size:{pred.shape}, target should have size: {pred.shape} or "
|
||||
f"{pred.shape[:-1]}, got {target.shape}.")
|
||||
if masks is not None:
|
||||
target = target * masks
|
||||
pred = pred * masks
|
||||
|
||||
target_idxes = set(target.reshape(-1).tolist())
|
||||
for target_idx in target_idxes:
|
||||
self._tp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item()
|
||||
self._fp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item()
|
||||
self._fn[target_idx] += ((pred != target_idx) * (target != target_idx)).sum().item()
|
||||
self._tp[target_idx] += ((pred == target_idx) * (target == target_idx) * masks).sum().item()
|
||||
self._fp[target_idx] += ((pred == target_idx) * (target != target_idx) * masks).sum().item()
|
||||
self._fn[target_idx] += ((pred != target_idx) * (target == target_idx) * masks).sum().item()
|
||||
|
@ -31,7 +31,7 @@ def _test(local_rank: int, world_size: int, device: "torch.device",
|
||||
|
||||
my_result = metric.get_metric()
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
np.allclose(my_result[keys], metric_result[keys], atol=0.000001)
|
||||
assert np.allclose(my_result[keys], metric_result[keys], atol=0.000001)
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
@ -69,7 +69,6 @@ class TestClassfiyFPreRecMetric:
|
||||
[-0.8088, -0.6648, -0.5018, -0.0230, -0.8207],
|
||||
[-0.7753, -0.3508, 1.6163, 0.7158, 1.5207],
|
||||
[0.8692, 0.7718, -0.6734, 0.6515, 0.0641]])
|
||||
arg_max_pred = torch.argmax(pred, dim=-1)
|
||||
target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3,
|
||||
0, 3, 0, 0, 0, 1, 3, 1])
|
||||
|
||||
@ -79,10 +78,9 @@ class TestClassfiyFPreRecMetric:
|
||||
f1_score = 0.1882051282051282
|
||||
recall = 0.1619047619047619
|
||||
pre = 0.23928571428571427
|
||||
|
||||
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
|
||||
assert np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
|
||||
|
||||
metric = ClassifyFPreRecMetric(f_type='micro')
|
||||
metric.update(pred, target)
|
||||
@ -93,7 +91,7 @@ class TestClassfiyFPreRecMetric:
|
||||
|
||||
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
|
||||
assert np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
|
||||
|
||||
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro')
|
||||
metric.update(pred, target)
|
||||
@ -103,19 +101,35 @@ class TestClassfiyFPreRecMetric:
|
||||
'1': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 5},
|
||||
'2': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 2},
|
||||
'3': {'f1-score': 0.30769230769230765, 'precision': 0.5, 'recall': 0.2222222222222222, 'support': 9},
|
||||
'4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9},
|
||||
'macro avg': {'f1-score': 0.1882051282051282, 'precision': 0.23928571428571427,
|
||||
'recall': 0.1619047619047619, 'support': 32},
|
||||
'micro avg': {'f1-score': 0.21875, 'precision': 0.21875, 'recall': 0.21875, 'support': 32},
|
||||
'weighted avg': {'f1-score': 0.2563301282051282, 'precision': 0.3286830357142857, 'recall': 0.21875,
|
||||
'support': 32}}
|
||||
'4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9}}
|
||||
for keys in result_dict.keys():
|
||||
if keys == "f" or "pre" or "rec":
|
||||
continue
|
||||
gl = str(keys[-1])
|
||||
tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"}
|
||||
gk = tmp_d[keys[0]]
|
||||
np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001)
|
||||
assert np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001)
|
||||
|
||||
def test_seq_len(self):
|
||||
pred = torch.tensor([[[0.3, 0.7, 0.1], [0.4, 0.1, 0.1], [0.3, 0.1, 0.7]],
|
||||
[[0.7, 0.1, 0.1], [0.5, 0.9, 0.1], [0.3, 0.1, 0.7]]])
|
||||
seq_len = torch.LongTensor([3, 2])
|
||||
target = torch.LongTensor([[1, 0, 2], [0, 1, 0]])
|
||||
|
||||
# 不考虑长度
|
||||
metric = ClassifyFPreRecMetric(only_gross=True, f_type='macro')
|
||||
metric.update(pred, target)
|
||||
result_dict = metric.get_metric()
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
assert result_dict[keys] != 1
|
||||
|
||||
# 考虑长度
|
||||
metric = ClassifyFPreRecMetric(only_gross=True, f_type='macro')
|
||||
metric.update(pred, target, seq_len=seq_len)
|
||||
result_dict = metric.get_metric()
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
assert result_dict[keys] == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("f_type, f1_score,recall,pre",
|
||||
[('macro', 0.1882051282051282, 0.1619047619047619, 0.23928571428571427),
|
||||
@ -180,3 +194,4 @@ class TestClassfiyFPreRecMetric:
|
||||
[(rank, NUM_PROCESSES, torch.device(f'cuda:{rank}')) for rank in range(NUM_PROCESSES)])
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
|
@ -226,7 +226,7 @@ class TestSpanFPreRecMetric:
|
||||
# print(expected_metric)
|
||||
metric_value = metric.get_metric()
|
||||
for key, value in expected_metric.items():
|
||||
np.allclose(value, metric_value[key])
|
||||
assert np.allclose(value, metric_value[key])
|
||||
|
||||
def test_auto_encoding_type_infer(self):
|
||||
# 检查是否可以自动check encode的类型
|
||||
|
Loading…
Reference in New Issue
Block a user