增加dataset, collator文档

This commit is contained in:
MorningForest 2022-05-22 19:12:47 +08:00
parent 7ca1abfba5
commit 724b112a34
8 changed files with 376 additions and 111 deletions

View File

@ -41,6 +41,11 @@ def is_jittor_tensor(dtype):
def is_jittor_dtype_str(dtype): def is_jittor_dtype_str(dtype):
"""
判断数据类型是否为 jittor 使用的字符串类型
:param: dtype 数据类型
"""
try: try:
if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8', if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128', 'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
@ -53,6 +58,13 @@ def is_jittor_dtype_str(dtype):
def _get_dtype(ele_dtype, dtype, class_name): def _get_dtype(ele_dtype, dtype, class_name):
"""
用于检测数据的 dtype 类型 根据内部和外部数据判断
:param ele_dtype 内部数据的类型
:param dtype 数据外部类型
:param class_name 类的名称
"""
if not (ele_dtype is None or ( if not (ele_dtype is None or (
is_number_or_numpy_number(ele_dtype) or is_jittor_tensor(ele_dtype) or is_jittor_dtype_str(dtype))): is_number_or_numpy_number(ele_dtype) or is_jittor_tensor(ele_dtype) or is_jittor_dtype_str(dtype))):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
@ -62,13 +74,7 @@ def _get_dtype(ele_dtype, dtype, class_name):
if not (is_jittor_tensor(dtype) or is_number(dtype) or is_jittor_dtype_str(dtype)): if not (is_jittor_tensor(dtype) or is_number(dtype) or is_jittor_dtype_str(dtype)):
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
f"or jittor.dtype but get `{dtype}`.") f"or jittor.dtype but get `{dtype}`.")
# dtype = number_to_jittor_dtype_dict.get(dtype, dtype)
else: else:
# if (is_number(ele_dtype) or is_jittor_tensor(ele_dtype)):
# # ele_dtype = number_to_jittor_dtype_dict.get(ele_dtype, ele_dtype)
# dtype = ele_dtype
# elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了
# dtype = numpy_to_jittor_dtype_dict.get(ele_dtype.type)
if is_numpy_generic_class(ele_dtype): if is_numpy_generic_class(ele_dtype):
dtype = numpy_to_jittor_dtype_dict.get(ele_dtype) dtype = numpy_to_jittor_dtype_dict.get(ele_dtype)
else: else:
@ -91,6 +97,11 @@ class JittorNumberPadder(Padder):
@staticmethod @staticmethod
def pad(batch_field, pad_val=0, dtype=None): def pad(batch_field, pad_val=0, dtype=None):
"""
:param batch_field 输入的某个 field batch 数据
:param pad_val 需要填充的值
:dtype 数据的类型
"""
return jittor.Var(np.array(batch_field, dtype=dtype)) return jittor.Var(np.array(batch_field, dtype=dtype))
@ -108,6 +119,11 @@ class JittorSequencePadder(Padder):
@staticmethod @staticmethod
def pad(batch_field, pad_val=0, dtype=None): def pad(batch_field, pad_val=0, dtype=None):
"""
:param batch_field 输入的某个 field batch 数据
:param pad_val 需要填充的值
:dtype 数据的类型
"""
tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val) tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val)
return tensor return tensor
@ -126,6 +142,13 @@ class JittorTensorPadder(Padder):
@staticmethod @staticmethod
def pad(batch_field, pad_val=0, dtype=None): def pad(batch_field, pad_val=0, dtype=None):
"""
batch_field 数据 转为 jittor.Var pad 到相同长度
:param batch_field 输入的某个 field batch 数据
:param pad_val 需要填充的值
:dtype 数据的类型
"""
try: try:
if not isinstance(batch_field[0], jittor.Var): if not isinstance(batch_field[0], jittor.Var):
batch_field = [jittor.Var(np.array(field.tolist(), dtype=dtype)) for field in batch_field] batch_field = [jittor.Var(np.array(field.tolist(), dtype=dtype)) for field in batch_field]
@ -139,9 +162,6 @@ class JittorTensorPadder(Padder):
else: else:
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
# if dtype is not None:
# tensor = jittor.full(max_shape, pad_val, dtype=dtype)
# else:
tensor = jittor.full(max_shape, pad_val, dtype=dtype) tensor = jittor.full(max_shape, pad_val, dtype=dtype)
for i, field in enumerate(batch_field): for i, field in enumerate(batch_field):
slices = (i,) + tuple(slice(0, s) for s in shapes[i]) slices = (i,) + tuple(slice(0, s) for s in shapes[i])

View File

@ -15,6 +15,13 @@ from .exceptions import *
def _get_dtype(ele_dtype, dtype, class_name): def _get_dtype(ele_dtype, dtype, class_name):
"""
用于检测数据的 dtype 类型 根据内部和外部数据判断
:param ele_dtype 内部数据的类型
:param dtype 数据外部类型
:param class_name 类的名称
"""
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers but get `{ele_dtype}`.") f"or numpy numbers but get `{ele_dtype}`.")

View File

@ -36,6 +36,11 @@ from .exceptions import *
def is_paddle_tensor(dtype): def is_paddle_tensor(dtype):
"""
判断 dtype 是否为 paddle tensor
:param dtype 数据的 dtype 类型
"""
if not isclass(dtype) and isinstance(dtype, paddle.dtype): if not isclass(dtype) and isinstance(dtype, paddle.dtype):
return True return True
@ -43,6 +48,12 @@ def is_paddle_tensor(dtype):
def is_paddle_dtype_str(dtype): def is_paddle_dtype_str(dtype):
"""
判断 dtype str 类型 且属于 paddle 支持的 str 类型
:param dtype 数据的 dtype 类型
"""
try: try:
if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8', if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128', 'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
@ -56,6 +67,13 @@ def is_paddle_dtype_str(dtype):
def _get_dtype(ele_dtype, dtype, class_name): def _get_dtype(ele_dtype, dtype, class_name):
"""
用于检测数据的 dtype 类型 根据内部和外部数据判断
:param ele_dtype 内部数据的类型
:param dtype 数据外部类型
:param class_name 类的名称
"""
if not (ele_dtype is None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)): if not (ele_dtype is None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.") f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.")

View File

@ -10,6 +10,13 @@ from .exceptions import *
def _get_dtype(ele_dtype, dtype, class_name): def _get_dtype(ele_dtype, dtype, class_name):
"""
用于检测数据的 dtype 类型 根据内部和外部数据判断
:param ele_dtype 内部数据的类型
:param dtype 数据外部类型
:param class_name 类的名称
"""
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers but get `{ele_dtype}`.") f"or numpy numbers but get `{ele_dtype}`.")

View File

@ -35,12 +35,24 @@ from .exceptions import *
def is_torch_tensor(dtype): def is_torch_tensor(dtype):
"""
判断是否为 torch tensor
:param dtype 数据的 dtype 类型
"""
if not isclass(dtype) and isinstance(dtype, torch.dtype): if not isclass(dtype) and isinstance(dtype, torch.dtype):
return True return True
return False return False
def _get_dtype(ele_dtype, dtype, class_name): def _get_dtype(ele_dtype, dtype, class_name):
"""
用于检测数据的 dtype 类型 根据内部和外部数据判断
:param ele_dtype 内部数据的类型
:param dtype 数据外部类型
:param class_name 类的名称
"""
if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.")

View File

@ -1,4 +1,150 @@
r""" r"""
:class:`~fastNLP.core.dataset.DataSet` fastNLP 中用于承载数据的容器可以将 DataSet 看做是一个表格
每一行是一个 sample ( fastNLP 中被称为 :mod:`~fastNLP.core.instance` )
每一列是一个 feature ( fastNLP 中称为 :mod:`~fastNLP.core.field` )
.. csv-table:: Following is a demo layout of DataSet
:header: "sentence", "words", "seq_len"
"This is the first instance .", "[This, is, the, first, instance, .]", 6
"Second instance .", "[Second, instance, .]", 3
"Third instance .", "[Third, instance, .]", 3
"...", "[...]", "..."
fastNLP 内部每一行是一个 :class:`~fastNLP.Instance` 对象 每一列是一个 :class:`~fastNLP.FieldArray` 对象
----------------------------
1.DataSet的创建
----------------------------
创建DataSet主要有以下的3种方式
1.1 传入dict
----------------------------
.. code-block::
from fastNLP import DataSet
data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."],
'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.'],
'seq_len': [6, 3, 3]}
dataset = DataSet(data)
# 传入的 dict 的每个 key 的 value 应该为具有相同长度的l ist
1.2 通过 Instance 构建
----------------------------
.. code-block::
from fastNLP import DataSet
from fastNLP import Instance
dataset = DataSet()
instance = Instance(sentence="This is the first instance",
words=['this', 'is', 'the', 'first', 'instance', '.'],
seq_len=6)
dataset.append(instance)
# 可以继续 append 更多内容,但是 append 的 instance 应该和第一个 instance 拥有完全相同的 field
1.3 通过 List[Instance] 构建
--------------------------------------
.. code-block::
from fastNLP import DataSet
from fastNLP import Instance
instances = []
winstances.append(Instance(sentence="This is the first instance",
ords=['this', 'is', 'the', 'first', 'instance', '.'],
seq_len=6))
instances.append(Instance(sentence="Second instance .",
words=['Second', 'instance', '.'],
seq_len=3))
dataset = DataSet(instances)
--------------------------------------
2.DataSet 与预处理
--------------------------------------
常见的预处理有如下几种
2.1 从某个文本文件读取内容
--------------------------------------
.. code-block::
from fastNLP import DataSet
from fastNLP import Instance
dataset = DataSet()
filepath = 'some/text/file'
# 假设文件中每行内容如下(sentence label):
# This is a fantastic day positive
# The bad weather negative
# .....
with open(filepath, 'r') as f:
for line in f:
sent, label = line.strip().split('\t')
dataset.append(Instance(sentence=sent, label=label))
2.2 DataSet 中的内容处理
--------------------------------------
.. code-block::
from fastNLP import DataSet
data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."]}
dataset = DataSet(data)
# 将句子分成单词形式, 详见DataSet.apply()方法, 可以开启多进程来加快处理, 也可以更改展示的bar目前支持 ``['rich', 'tqdm', None]``,
# 详细内容可以见 :class: `~fastNLP.core.dataset.DataSet`, 需要注意的时匿名函数不支持多进程
dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words',
progress_des='Main',progress_bar='rich')
# 或使用DataSet.apply_field()
dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words',
progress_des='Main',progress_bar='rich')
# 除了匿名函数,也可以定义函数传递进去
def get_words(instance):
sentence = instance['sentence']
words = sentence.split()
return words
dataset.apply(get_words, new_field_name='words' num_proc=2, progress_des='Main',progress_bar='rich')
2.3 删除DataSet的内容
--------------------------------------
.. code-block::
from fastNLP import DataSet
dataset = DataSet({'a': list(range(-5, 5))})
# 返回满足条件的 instance,并放入 DataSet 中
dropped_dataset = dataset.drop(lambda ins:ins['a']<0, inplace=False)
# 在 dataset 中删除满足条件的i nstance
dataset.drop(lambda ins:ins['a']<0) # dataset 的 instance数量减少
# 删除第 3 个 instance
dataset.delete_instance(2)
# 删除名为 'a' 的 field
dataset.delete_field('a')
2.4 遍历DataSet的内容
--------------------------------------
.. code-block::
for instance in dataset:
# do something
2.5 一些其它操作
--------------------------------------
.. code-block::
# 检查是否存在名为 'a' 的 field
dataset.has_field('a') # 或 ('a' in dataset)
# 将名为 'a' 的 field 改名为 'b'
dataset.rename_field('a', 'b')
# DataSet 的长度
len(dataset)
""" """
__all__ = [ __all__ = [
@ -42,9 +188,9 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, p
""" """
对数据集进行处理封装函数以便多进程使用 对数据集进行处理封装函数以便多进程使用
:param ds: 数据集 :param ds: 实现了 __getitem__() __len__() 的对象
:param _apply_field: 需要处理数据集的field_name :param _apply_field: 需要处理数据集的 field_name
:param func: 用户自定义的func :param func: 用户自定义的 func
:param desc: 进度条的描述字符 :param desc: 进度条的描述字符
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]` :param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
:return: :return:
@ -76,9 +222,9 @@ def _multi_proc(ds, _apply_field, func, counter, queue):
""" """
对数据集进行处理封装函数以便多进程使用 对数据集进行处理封装函数以便多进程使用
:param ds: 数据集 :param ds: 实现了 __getitem__() __len__() 的对象
:param _apply_field: 需要处理数据集的field_name :param _apply_field: 需要处理数据集的 field_name
:param func: 用户自定义的func :param func: 用户自定义的 func
:param counter: 计数器 :param counter: 计数器
:param queue: 多进程时将结果输入到这个 queue :param queue: 多进程时将结果输入到这个 queue
:return: :return:
@ -111,9 +257,28 @@ class DataSet:
def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None): def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None):
r""" r"""
初始化 ``DataSet``, fastNLP的 DataSet key-value 存储形式 目前支持两种初始化方式输入 data 分别为 ``List[:class: `~fastNLP.core.dataset.Instance`]``
``Dict[str, List[Any]]``
* data ``List[:class: `~fastNLP.core.dataset.Instance`]`` , 每个 ``Instance`` field_name 需要保持一致
Instance 详见 :class: `~fastNLP.core.dataset.Instance`
* data ``Dict[str, List[Any]] 则每个 key value 应该为等长的 list 否则不同 field 的长度不一致
:param data: 初始化的内容 其只能为两种类型分别为 ``List[:class: `~fastNLP.core.dataset.Instance`]``
``Dict[str, List[Any]]``
* data ``List[:class: `~fastNLP.core.dataset.Instance`]`` , 每个 ``Instance`` field_name 需要保持一致
Instance 详见 :class: `~fastNLP.core.dataset.Instance`
* data ``Dict[str, List[Any]] 则每个 key value 应该为等长的 list 否则不同 field 的长度不一致
Example::
from fastNLP.core.dataset import DataSet, Instance
data = {'x': [[1, 0, 1], [0, 1, 1], 'y': [0, 1]}
data1 = [Instance(x=[1,0,1],y=0), Instance(x=[0,1,1],y=1)]
ds = DataSet(data)
ds = DataSet(data1)
:param data: 如果为dict类型则每个key的value应该为等长的list; 如果为list
每个元素应该为具有相同field的 :class:`~fastNLP.Instance`
""" """
self.field_arrays = {} self.field_arrays = {}
self._collator = Collator() self._collator = Collator()
@ -168,11 +333,27 @@ class DataSet:
return inner_iter_func() return inner_iter_func()
def __getitem__(self, idx: Union[int, slice, str, list]): def __getitem__(self, idx: Union[int, slice, str, list]):
r"""给定int的index返回一个Instance; 给定slice返回包含这个slice内容的新的DataSet。 r"""
DataSet 的内容 根据 idx 类型不同有不同的返回值 包括四种类型 ``[int, slice, str, list]``
:param idx: can be int or slice. * idx ``int`` idx 的值不能超过 ``DataSet`` 的长度, 会返回一个 ``Instance``, 详见
:return: If `idx` is int, return an Instance object. :class: `~fastNLP.core.dataset.Instance`
If `idx` is slice, return a DataSet object. * idx ``slice`` 会根据 slice 的内容创建一个新的 DataSet其包含 slice 所有内容并返回
* idx ``str`` idx DataSet field_name, 其会返回该 field_name 的所有内容 list 类型
* idx ``list`` idx list 内全为 int 数字 其会取出所有内容组成一个新的 DataSet 返回
Example::
from fastNLP.core.dataset import DataSet
ds = DataSet({'x': [[1, 0, 1], [0, 1, 1] * 100, 'y': [0, 1] * 100})
ins = ds[0]
sub_ds = ds[0:100]
sub_ds= ds[[1, 0, 3, 2, 1, 4]]
field = ds['x']
:param idx: 用户传入参数
:return:
""" """
if isinstance(idx, int): if isinstance(idx, int):
return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays}) return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays})
@ -230,9 +411,10 @@ class DataSet:
return self.__dict__ return self.__dict__
def __len__(self): def __len__(self):
r"""Fetch the length of the dataset. r"""
获取 DataSet 的长度
:return length: :return
""" """
if len(self.field_arrays) == 0: if len(self.field_arrays) == 0:
return 0 return 0
@ -244,9 +426,9 @@ class DataSet:
def append(self, instance: Instance) -> None: def append(self, instance: Instance) -> None:
r""" r"""
将一个instance对象append到DataSet后面 将一个 instance 对象 append DataSet 后面详见 :class: `~fastNLP.Instance`
:param ~fastNLP.Instance instance: DataSet不为空instance应该拥有和DataSet完全一样的field :param instance: DataSet 不为空 instance 应该拥有和 DataSet 完全一样的 field
""" """
if len(self.field_arrays) == 0: if len(self.field_arrays) == 0:
@ -269,10 +451,10 @@ class DataSet:
def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None:
r""" r"""
fieldarray添加到DataSet中. fieldarray 添加到 DataSet .
:param str field_name: 新加入的field的名称 :param field_name: 新加入的 field 的名称
:param ~fastNLP.core.FieldArray fieldarray: 需要加入DataSet的field的内容 :param fieldarray: 需要加入 DataSet field 的内容, 详见 :class: `~fastNLP.core.dataset.FieldArray`
:return: :return:
""" """
if not isinstance(fieldarray, FieldArray): if not isinstance(fieldarray, FieldArray):
@ -285,10 +467,10 @@ class DataSet:
def add_field(self, field_name: str, fields: list) -> None: def add_field(self, field_name: str, fields: list) -> None:
r""" r"""
新增一个field 需要注意的是fields的长度跟dataset长度一致 新增一个 field 需要注意的是 fields 的长度跟 DataSet 长度一致
:param str field_name: 新增的field的名称 :param field_name: 新增的 field 的名称
:param list fields: 需要新增的field的内容 :param fields: 需要新增的 field 的内容
""" """
if len(self.field_arrays) != 0: if len(self.field_arrays) != 0:
@ -299,9 +481,9 @@ class DataSet:
def delete_instance(self, index: int): def delete_instance(self, index: int):
r""" r"""
删除第index个instance 删除第 ``index `` Instance
:param int index: 需要删除的instance的index序号从0开始 :param index: 需要删除的 instanc e的 index序号从 `0` 开始
""" """
assert isinstance(index, int), "Only integer supported." assert isinstance(index, int), "Only integer supported."
if len(self) <= index: if len(self) <= index:
@ -315,9 +497,9 @@ class DataSet:
def delete_field(self, field_name: str): def delete_field(self, field_name: str):
r""" r"""
删除名为field_name的field 删除名为 field_name field
:param str field_name: 需要删除的field的名称. :param field_name: 需要删除的 field 的名称.
""" """
if self.has_field(field_name): if self.has_field(field_name):
self.field_arrays.pop(field_name) self.field_arrays.pop(field_name)
@ -327,10 +509,10 @@ class DataSet:
def copy_field(self, field_name: str, new_field_name: str): def copy_field(self, field_name: str, new_field_name: str):
r""" r"""
深度copy名为field_name的field到new_field_name 深度 copy 名为 field_name field new_field_name
:param str field_name: 需要copy的field :param field_name: 需要 copy field
:param str new_field_name: copy生成的field名称 :param new_field_name: copy 生成的 field 名称
:return: self :return: self
""" """
if not self.has_field(field_name): if not self.has_field(field_name):
@ -342,10 +524,10 @@ class DataSet:
def has_field(self, field_name: str) -> bool: def has_field(self, field_name: str) -> bool:
r""" r"""
判断DataSet中是否有名为field_name这个field 判断 DataSet 中是否有名为 field_name 这个 field
:param str field_name: field的名称 :param field_name: field 的名称
:return bool: 表示是否有名为field_name这个field :return: 表示是否有名为 field_name 这个 field
""" """
if isinstance(field_name, str): if isinstance(field_name, str):
return field_name in self.field_arrays return field_name in self.field_arrays
@ -353,9 +535,9 @@ class DataSet:
def get_field(self, field_name: str) -> FieldArray: def get_field(self, field_name: str) -> FieldArray:
r""" r"""
获取field_name这个field 获取 field_name 这个 field
:param str field_name: field的名称 :param field_name: field 的名称
:return: :class:`~fastNLP.FieldArray` :return: :class:`~fastNLP.FieldArray`
""" """
if field_name not in self.field_arrays: if field_name not in self.field_arrays:
@ -364,34 +546,34 @@ class DataSet:
def get_all_fields(self) -> dict: def get_all_fields(self) -> dict:
r""" r"""
返回一个dictkey为field_name, value为对应的 :class:`~fastNLP.FieldArray` 返回一个 dictkey field_name, value为对应的 :class:`~fastNLP.FieldArray`
:return dict: 返回如上所述的字典 :return: 返回如上所述的字典
""" """
return self.field_arrays return self.field_arrays
def get_field_names(self) -> list: def get_field_names(self) -> list:
r""" r"""
返回一个list包含所有 field 的名字 返回一个 list包含所有 field 的名字
:return list: 返回如上所述的列表 :return: 返回如上所述的列表
""" """
return sorted(self.field_arrays.keys()) return sorted(self.field_arrays.keys())
def get_length(self) -> int: def get_length(self) -> int:
r""" r"""
获取DataSet的元素数量 获取 DataSet 的元素数量
:return: int: DataSet中Instance的个数 :return: DataSet Instance 的个数
""" """
return len(self) return len(self)
def rename_field(self, field_name: str, new_field_name: str): def rename_field(self, field_name: str, new_field_name: str):
r""" r"""
将某个field重新命名. 将某个 field 重新命名.
:param str field_name: 原来的field名称 :param field_name: 原来的 field 名称
:param str new_field_name: 修改为new_name :param new_field_name: 修改为 new_name
""" """
if field_name in self.field_arrays: if field_name in self.field_arrays:
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
@ -627,10 +809,10 @@ class DataSet:
def add_seq_len(self, field_name: str, new_field_name='seq_len'): def add_seq_len(self, field_name: str, new_field_name='seq_len'):
r""" r"""
将使用len()直接对field_name中每个元素作用将其结果作为sequence length, 并放入seq_len这个field 将使用 len() 直接对 field_name 中每个元素作用将其结果作为 sequence length, 并放入 seq_len 这个 field
:param field_name: str. :param field_name: 需要处理的 field_name
:param new_field_name: str. 新的field_name :param new_field_name: str. 新的 field_name
:return: :return:
""" """
if self.has_field(field_name=field_name): if self.has_field(field_name=field_name):
@ -641,10 +823,11 @@ class DataSet:
def drop(self, func: Callable, inplace=True): def drop(self, func: Callable, inplace=True):
r""" r"""
func接受一个Instance返回bool值返回值为True时该Instance会被移除或者不会包含在返回的DataSet中 删除某些 Instance 需要注意的时func 接受一个 Instance 返回 bool 返回值为 True
Instance 会被移除或者不会包含在返回的 DataSet
:param callable func: 接受一个Instance作为参数返回bool值True时删除该instance :param func: 接受一个 Instance 作为参数返回 bool True 时删除该 instance
:param bool inplace: 是否在当前DataSet中直接删除instance如果为False将返回一个新的DataSet :param inplace: 是否在当前 DataSet 中直接删除 instance如果为 False将返回一个新的 DataSet
:return: DataSet :return: DataSet
""" """
@ -663,10 +846,10 @@ class DataSet:
def split(self, ratio: float, shuffle=True): def split(self, ratio: float, shuffle=True):
r""" r"""
DataSet按照ratio的比例拆分返回两个DataSet DataSet 按照 ratio 的比例拆分返回两个 DataSet
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `ratio` 这么多数据第二个DataSet拥有`(1-ratio)`这么多数据 :param ratio: 0<ratio<1, 返回的第一个 DataSet 拥有 `ratio` 这么多数据第二个 DataSet 拥有 `(1-ratio)` 这么多数据
:param bool shuffle: split前是否shuffle一下False返回的第一个dataset就是当前dataset中前`ratio`比例的数据 :param shuffle: split 前是否 shuffle 一下 False返回的第一个 dataset 就是当前 dataset 中前 `ratio` 比例的数据
:return: [ :class:`~fastNLP.读取后的DataSet` , :class:`~fastNLP.读取后的DataSet` ] :return: [ :class:`~fastNLP.读取后的DataSet` , :class:`~fastNLP.读取后的DataSet` ]
""" """
assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.' assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.'
@ -696,7 +879,7 @@ class DataSet:
r""" r"""
保存DataSet. 保存DataSet.
:param str path: 将DataSet存在哪个路径 :param path: 将DataSet存在哪个路径
""" """
with open(path, 'wb') as f: with open(path, 'wb') as f:
pickle.dump(self, f) pickle.dump(self, f)
@ -704,9 +887,9 @@ class DataSet:
@staticmethod @staticmethod
def load(path: str): def load(path: str):
r""" r"""
从保存的DataSet pickle文件的路径中读取DataSet 从保存的 DataSet pickle文件的路径中读取DataSet
:param str path: 从哪里读取DataSet :param path: 从哪里读取 DataSet
:return: 读取后的 :class:`~fastNLP.读取后的DataSet` :return: 读取后的 :class:`~fastNLP.读取后的DataSet`
""" """
with open(path, 'rb') as f: with open(path, 'rb') as f:
@ -716,16 +899,16 @@ class DataSet:
def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet':
""" """
将当前dataset与输入的dataset结合成一个更大的dataset需要保证两个dataset都包含了相同的field结合后的dataset的input,target 将当前 dataset 与输入的 dataset 结合成一个更大的 dataset需要保证两个 dataset 都包含了相同的 field结合后的 dataset
以及collate_fn以当前dataset为准当dataset中包含的field多于当前的dataset则多余的field会被忽略若dataset中未包含所有 field_name _collator 以当前 dataset 为准 dataset 中包含的 field 多于当前的 dataset则多余的 field 会被忽略
当前dataset含有field则会报错 dataset 中未包含所有当前 dataset 含有 field则会报错
:param DataSet, dataset: 需要和当前dataset concat的dataset :param dataset: 需要和当前 dataset concat的 dataset
:param bool, inplace: 是否直接将dataset组合到当前dataset中 :param inplace: 是否直接将 dataset 组合到当前 dataset
:param dict, field_mapping: 当传入的dataset中的field名称和当前dataset不一致时需要通过field_mapping把输入的dataset中的 :param field_mapping: 当传入的 dataset 中的 field 名称和当前 dataset 不一致时需要通过 field_mapping 把输入的 dataset 中的
field名称映射到当前field. field_mapping为dict类型key为dataset中的field名称value是需要映射成的名称 field 名称映射到当前 field. field_mapping dict 类型key dataset 中的 field 名称value 是需要映射成的名称
:return: DataSet :return: :class: `~fastNLP.core.dataset.DataSet``
""" """
assert isinstance(dataset, DataSet), "Can only concat two datasets." assert isinstance(dataset, DataSet), "Can only concat two datasets."
@ -754,8 +937,8 @@ class DataSet:
@classmethod @classmethod
def from_pandas(cls, df): def from_pandas(cls, df):
""" """
pandas.DataFrame中读取数据转为Dataset ``pandas.DataFrame`` 中读取数据转为 DataSet
:param df: :param df: 使用 pandas 读取的数据
:return: :return:
""" """
df_dict = df.to_dict(orient='list') df_dict = df.to_dict(orient='list')
@ -763,7 +946,7 @@ class DataSet:
def to_pandas(self): def to_pandas(self):
""" """
dataset转为pandas.DataFrame类型的数据 DataSet 数据转为 ``pandas.DataFrame`` 类型的数据
:return: :return:
""" """
@ -773,9 +956,9 @@ class DataSet:
def to_csv(self, path: str): def to_csv(self, path: str):
""" """
dataset保存为csv文件 DataSet 保存为 csv 文件
:param path: :param path: 保存到路径
:return: :return:
""" """

View File

@ -16,6 +16,13 @@ import numpy as np
class FieldArray: class FieldArray:
def __init__(self, name: str, content): def __init__(self, name: str, content):
"""
初始化 FieldArray
:param name: 字符串的名称
:param content: 任意类型的数据
"""
if len(content) == 0: if len(content) == 0:
raise RuntimeError("Empty fieldarray is not allowed.") raise RuntimeError("Empty fieldarray is not allowed.")
_content = content _content = content
@ -29,15 +36,17 @@ class FieldArray:
def append(self, val: Any) -> None: def append(self, val: Any) -> None:
r""" r"""
:param val: 把该val append到fieldarray :param val: 把该 val append fieldarray
:return: :return:
""" """
self.content.append(val) self.content.append(val)
def pop(self, index: int) -> None: def pop(self, index: int) -> None:
r""" r"""
删除该field中index处的元素 删除该 field index 处的元素
:param int index: 从0开始的数据下标
:param index: ``0`` 开始的数据下标
:return: :return:
""" """
self.content.pop(index) self.content.pop(index)
@ -51,10 +60,10 @@ class FieldArray:
def get(self, indices: Union[int, List[int]]): def get(self, indices: Union[int, List[int]]):
r""" r"""
根据给定的indices返回内容 根据给定的 indices 返回内容
:param int,List[int] indices: 获取indices对应的内容 :param indices: 获取 indices 对应的内容
:return: 根据给定的indices返回的内容可能是单个值或ndarray :return: 根据给定的 indices 返回的内容可能是单个值 ``ndarray``
""" """
if isinstance(indices, int): if isinstance(indices, int):
if indices == -1: if indices == -1:
@ -69,18 +78,18 @@ class FieldArray:
def __len__(self): def __len__(self):
r""" r"""
Returns the size of FieldArray. 返回长度
:return int length: :return length:
""" """
return len(self.content) return len(self.content)
def split(self, sep: str = None, inplace: bool = True): def split(self, sep: str = None, inplace: bool = True):
r""" r"""
依次对自身的元素使用.split()方法应该只有当本field的元素为str时该方法才有用 依次对自身的元素使用 ``.split()`` 方法应该只有当本 field 的元素为 ``str`` 该方法才有用
:param sep: 分割符如果为None则直接调用str.split() :param sep: 分割符如果为 ``None`` 则直接调用 ``str.split()``
:param inplace: 如果为True则将新生成值替换本field否则返回list :param inplace: 如果为 ``True``则将新生成值替换本 field否则返回 ``list``
:return: List[List[str]] or self :return: List[List[str]] or self
""" """
new_contents = [] new_contents = []
@ -94,10 +103,11 @@ class FieldArray:
def int(self, inplace: bool = True): def int(self, inplace: bool = True):
r""" r"""
将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的) 将本 field 中的值调用 ``int(cell)``. 支持 field 中内容为以下两种情况:
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个listlist中的值会被依次转换) * ['1', '2', ...]( field 中每个值为 ``str`` )
* [['1', '2', ..], ['3', ..], ...]( field 中每个值为一个 ``list`` ``list`` 中的值会被依次转换)
:param inplace: 如果为True则将新生成值替换本field否则返回list :param inplace: 如果为 ``True``则将新生成值替换本 field否则返回 ``list``
:return: List[int], List[List[int]], self :return: List[int], List[List[int]], self
""" """
new_contents = [] new_contents = []
@ -114,10 +124,12 @@ class FieldArray:
def float(self, inplace=True): def float(self, inplace=True):
r""" r"""
将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的) 将本 field 中的值调用 ``float(cell)``. 支持 field 中内容为以下两种情况:
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个listlist中的值会被依次转换)
:param inplace: 如果为True则将新生成值替换本field否则返回list * ['1', '2', ...]( field 中每个值为 ``str`` )
* [['1', '2', ..], ['3', ..], ...]( field 中每个值为一个 ``list````list`` 中的值会被依次转换)
:param inplace: 如果为 ``True``则将新生成值替换本 ``field``否则返回 ``list``
:return: :return:
""" """
new_contents = [] new_contents = []
@ -134,10 +146,12 @@ class FieldArray:
def bool(self, inplace=True): def bool(self, inplace=True):
r""" r"""
将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的) 将本field中的值调用 ``bool(cell)``. 支持 field 中内容为以下两种情况
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个listlist中的值会被依次转换)
:param inplace: 如果为True则将新生成值替换本field否则返回list * ['1', '2', ...]( field 中每个值为 ``str`` )
* [['1', '2', ..], ['3', ..], ...]( field 中每个值为一个 ``list````list`` 中的值会被依次转换)
:param inplace: 如果为 ``True``则将新生成值替换本 ``field``否则返回 ``list``
:return: :return:
""" """
new_contents = [] new_contents = []
@ -155,10 +169,12 @@ class FieldArray:
def lower(self, inplace=True): def lower(self, inplace=True):
r""" r"""
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的) 将本 field 中的值调用 ``cell.lower()``. 支持 field 中内容为以下两种情况
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个listlist中的值会被依次转换)
:param inplace: 如果为True则将新生成值替换本field否则返回list * ['1', '2', ...]( ``field`` 中每个值为 ``str`` )
* [['1', '2', ..], ['3', ..], ...]( field 中每个值为一个 ``list````list``中的值会被依次转换)
:param inplace: 如果为 ``True``则将新生成值替换本 field否则返回 ``list``
:return: List[int], List[List[int]], self :return: List[int], List[List[int]], self
""" """
new_contents = [] new_contents = []
@ -175,10 +191,12 @@ class FieldArray:
def upper(self, inplace=True): def upper(self, inplace=True):
r""" r"""
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的) 将本 field 中的值调用 ``cell.lower()``. 支持 field 中内容为以下两种情况
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个listlist中的值会被依次转换)
:param inplace: 如果为True则将新生成值替换本field否则返回list * ['1', '2', ...]( field 中每个值为 ``str`` )
* [['1', '2', ..], ['3', ..], ...]( field 中每个值为一个 ``list````list`` 中的值会被依次转换)
:param inplace: 如果为 ``True``则将新生成值替换本 field否则返回 ``list``
:return: List[int], List[List[int]], self :return: List[int], List[List[int]], self
""" """
new_contents = [] new_contents = []
@ -195,9 +213,9 @@ class FieldArray:
def value_count(self): def value_count(self):
r""" r"""
返回该field下不同value的数量多用于统计label数量 返回该 field 下不同 value的 数量多用于统计 label 数量
:return: Counter, key是labelvalue是出现次数 :return: Counter, key labelvalue 是出现次数
""" """
count = Counter() count = Counter()
@ -214,7 +232,7 @@ class FieldArray:
def _after_process(self, new_contents: list, inplace: bool): def _after_process(self, new_contents: list, inplace: bool):
r""" r"""
当调用处理函数之后决定是否要替换field 当调用处理函数之后决定是否要替换 field
:param new_contents: :param new_contents:
:param inplace: :param inplace:

View File

@ -1,5 +1,5 @@
r""" r"""
instance 模块实现了Instance 类在fastNLP中对应sample一个sample可以认为是一个Instance类型的对象 instance 模块实现了 Instance 类在 fastNLP 中对应 sample一个 sample 可以认为是一个 Instance 类型的对象
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset`
""" """
@ -27,16 +27,16 @@ class Instance(Mapping):
def add_field(self, field_name: str, field: any): def add_field(self, field_name: str, field: any):
r""" r"""
Instance中增加一个field Instance 中增加一个 field
:param str field_name: 新增field的名称 :param field_name: 新增 field 的名称
:param Any field: 新增field的内容 :param field: 新增 field 的内容
""" """
self.fields[field_name] = field self.fields[field_name] = field
def items(self): def items(self):
r""" r"""
返回一个迭代器迭代器返回两个内容第一个内容是field_name, 第二个内容是field_value 返回一个迭代器迭代器返回两个内容第一个内容是 field_name, 第二个内容是 field_value
:return: 一个迭代器 :return: 一个迭代器
""" """