Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

This commit is contained in:
yh 2022-05-23 19:43:35 +08:00
commit 388baf589c

View File

@ -997,3 +997,50 @@ class DataSet:
if self._collator is None:
self._collator = Collator()
return self._collator
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> Collator:
"""
``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数 ``collator`` :class: `~fastNLP.core.collators.Collator`
时该函数才有效调用该函数可以对 field 内容的 pad_val, dtype, backend 等进行调整
:param field_name: 需要调整的 field 的名称如果 DataSet __getitem__ 方法返回的是 dict 类型的则可以直接使用对应的
field key 来表示如果是 nested dict可以使用元组表示多层次的 key例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
如果 __getitem__ 返回的是 Sequence 类型的则可以使用 '_0', '_1' 表示序列中第 0 1 个元素如果该 field 在数据中没
有找到则报错如果 __getitem__ 返回的是就是整体内容请使用 "_single"
:param pad_val: 这个 field 的默认 pad 如果设置为 None则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad
field 进行 pad所以如果对应 field 本身就不是可以 pad 的形式可以不需要主动设置为 None 如果 backend None 该值
无意义
:param dtype: 对于需要 pad field field 的数据 dtype 应该是什么
:param backend: 可选['raw', 'numpy', 'torch', 'torch', 'jittor', 'auto']分别代表输出为 list, numpy.ndarray,
torch.Tensor, torch.Tensor, jittor.Var 类型 pad_val None 该值无意义
:param pad_fn: 指定当前 field pad 函数传入该函数则 pad_val, dtype, backend 等参数失效pad_fn 的输入为当前 field
batch 形式 Collator 将自动 unbatch 数据然后将各个 field 组成各自的 batch pad_func 的输入即为 field batch
形式输出将被直接作为结果输出
:return: 返回 Collator
"""
if isinstance(self.collator, Collator):
self.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self.collator
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
def set_ignore(self, *field_names) -> Collator:
"""
``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数 ``collator`` :class: `~fastNLP.core.collators.Collator`
时该函数才有效调用该函数可以设置忽略输出某些 field 的内容被设置的 field 将在 batch 的输出中被忽略
Example::
collator.set_ignore('field1', 'field2')
:param field_names: 需要忽略的 field 的名称如果 DataSet __getitem__ 方法返回的是 dict 类型的则可以直接使用对应的
field key 来表示如果是 nested dict可以使用元组来表示例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果
__getitem__ 返回的是 Sequence 类型的则可以使用 '_0', '_1' 表示序列中第 0 1 个元素
:return: 返回 Collator 自身
"""
if isinstance(self.collator, Collator):
self.collator.set_ignore(*field_names)
return self.collator
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")