From b9be645ed05f732ea66ef65dac2729bc6934de61 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Sun, 22 May 2022 21:12:42 +0800 Subject: [PATCH] =?UTF-8?q?dataset=20=E5=A2=9E=E5=8A=A0=20set=5Fpad,=20set?= =?UTF-8?q?=5Fignore?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset/dataset.py | 44 +++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 025d33e5..eb470261 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -970,3 +970,47 @@ class DataSet: if self._collator is None: self._collator = Collator() return self._collator + + def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, + pad_fn: Callable = None) -> Collator: + """ + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 + 无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选['raw', 'numpy', 'torch', 'torch', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, + torch.Tensor, torch.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: 返回 Collator + """ + if isinstance(self.collator, Collator): + self.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) + return self.collator + else: + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") + + def set_ignore(self, *field_names) -> Collator: + """ + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 + Example:: + + collator.set_ignore('field1', 'field2') + + :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 + __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 + :return: 返回 Collator 自身 + """ + if isinstance(self.collator, Collator): + self.collator.set_ignore(*field_names) + return self.collator + else: + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") \ No newline at end of file