This commit is contained in:
yh_cc 2022-05-07 14:14:46 +08:00
commit f1fa665e7c
32 changed files with 585 additions and 201 deletions

View File

@ -146,11 +146,13 @@ class CallbackManager:
r""" r"""
用于断点重训的 callback 的保存函数 用于断点重训的 callback 的保存函数
该函数主要涉及两个方面 该函数主要涉及两个方面
1. callback 的状态的保存我们会调用每一个 callback `on_save_checkpoint` 方法该方法应当返回一个字典其中包含着
断点重训应当保存的状态
2. 每一个具体的 callback 函数的 filter 的状态
:return: 一个包含上述内容的字典:: 1. callback 的状态的保存我们会调用每一个 callback `on_save_checkpoint` 方法该方法应当返回一个字典其中包含着
断点重训应当保存的状态
2. 每一个具体的 callback 函数的 filter 的状态
:return: 一个包含上述内容的字典:
.. code-block::
{ {
"callback_name_1": { "callback_name_1": {
@ -158,6 +160,7 @@ class CallbackManager:
"filter_states": {"on_train_begin": filter1.state_dict(), ...} "filter_states": {"on_train_begin": filter1.state_dict(), ...}
} }
} }
""" """
states = {} states = {}

View File

@ -39,7 +39,7 @@ class MoreEvaluateCallback(HasMonitorCallback):
意义是当检测到 Trainer evaluate results {watch_monitor} 的结果更好时则进行一次 evaluate 该参数有两种 意义是当检测到 Trainer evaluate results {watch_monitor} 的结果更好时则进行一次 evaluate 该参数有两种
取值: (1) str 类型监控的 metric 如果在 evaluation 结果中没有找到完全一致的名称将使用 最短公共字符串算法 找到最 取值: (1) str 类型监控的 metric 如果在 evaluation 结果中没有找到完全一致的名称将使用 最短公共字符串算法 找到最
匹配的那个作为 monitor ; (2) 也可以传入一个函数接受参数为 evaluation 的结果(字典类型)返回一个 float 值作为 monitor 匹配的那个作为 monitor ; (2) 也可以传入一个函数接受参数为 evaluation 的结果(字典类型)返回一个 float 值作为 monitor
的结果如果当前结果中没有相关的monitor 值请返回 None 的结果如果当前结果中没有相关的monitor 值请返回 None
:param watch_monitor_larger_better: watch_monitor 是否越大越好 :param watch_monitor_larger_better: watch_monitor 是否越大越好
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数例如是 `model.evaluate_step` 还是 :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数例如是 `model.evaluate_step` 还是
`model.forward`(1) 如果该值是 None那么我们会默认使用 `evaluate_step` 当做前向传播的函数如果在模型中没有 `model.forward`(1) 如果该值是 None那么我们会默认使用 `evaluate_step` 当做前向传播的函数如果在模型中没有

View File

@ -10,13 +10,13 @@ class TorchGradClipCallback(Callback):
在每次 optimizer update 之前将 parameter 进行 clip 在每次 optimizer update 之前将 parameter 进行 clip
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]clip_value应该为正数 :param float clip_value: 将gradient 限制到[-clip_value, clip_value]clip_value应该为正数
:param str clip_type: 支持'norm', 'value'两种:: :param str clip_type: 支持'norm', 'value'两种:
1 'norm', 将gradient的norm rescale到[-clip_value, clip_value] 1. 'norm', 将gradient的norm rescale到[-clip_value, clip_value]
2. 'value', 将gradient限制在[-clip_value, clip_value],
小于-clip_value的gradient被赋值为-clip_value;
大于clip_value的gradient被赋值为clip_value.
2 'value', 将gradient限制在[-clip_value, clip_value],
小于-clip_value的gradient被赋值为-clip_value;
大于clip_value的gradient被赋值为clip_value.
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得 :param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得
如果为None则默认对 Trainer optimizers 中所有参数进行梯度裁剪 如果为None则默认对 Trainer optimizers 中所有参数进行梯度裁剪
""" """

View File

@ -9,6 +9,7 @@ from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPad
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder
from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
from .jittor_padder import JittorTensorPadder, JittorSequencePadder, JittorNumberPadder
from .exceptions import * from .exceptions import *
@ -91,6 +92,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle': elif backend == 'paddle':
return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'jittor':
return JittorNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
else: else:
raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).") raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).")
@ -103,6 +106,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle': elif backend == 'paddle':
return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'jittor':
return JittorSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
else: else:
raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).") raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).")
@ -116,6 +121,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'paddle': elif backend == 'paddle':
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'jittor':
return JittorTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
else: else:
raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).") raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).")

View File

@ -0,0 +1,195 @@
__all__ = [
'JittorNumberPadder',
'JittorSequencePadder',
'JittorTensorPadder'
]
from inspect import isclass
import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
if _NEED_IMPORT_JITTOR:
import jittor
numpy_to_jittor_dtype_dict = {
np.bool_: 'bool',
np.uint8: 'uint8',
np.int8: "int8",
np.int16: "int16",
np.int32: "int32",
np.int64: "int64",
np.float16: "float16",
np.float32: 'float32',
np.float64: 'float32', # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了
}
# number_to_jittor_dtype_dict = {
# float: 'float32', # 因为 paddle.tensor([1], dtype=float)是paddle.float64
# int: 'int64',
# bool: 'bool'
# }
from .padder import Padder
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
from .exceptions import *
def is_jittor_tensor(dtype):
if not isclass(dtype) and isinstance(dtype, jittor.jittor_core.Var):
return True
return False
def is_jittor_dtype_str(dtype):
try:
if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8',
u'int16', u'int32', u'int64', u'uint8'}:
return True
except:
pass
return False
def _get_dtype(ele_dtype, dtype, class_name):
if not (ele_dtype is None or (
is_number_or_numpy_number(ele_dtype) or is_jittor_tensor(ele_dtype) or is_jittor_dtype_str(dtype))):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or jittor.Var but get `{ele_dtype}`.")
if dtype is not None:
if not (is_jittor_tensor(dtype) or is_number(dtype) or is_jittor_dtype_str(dtype)):
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
f"or jittor.dtype but get `{dtype}`.")
# dtype = number_to_jittor_dtype_dict.get(dtype, dtype)
else:
# if (is_number(ele_dtype) or is_jittor_tensor(ele_dtype)):
# # ele_dtype = number_to_jittor_dtype_dict.get(ele_dtype, ele_dtype)
# dtype = ele_dtype
# elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了
# dtype = numpy_to_jittor_dtype_dict.get(ele_dtype.type)
if is_numpy_generic_class(ele_dtype):
dtype = numpy_to_jittor_dtype_dict.get(ele_dtype)
else:
dtype = ele_dtype
return dtype
class JittorNumberPadder(Padder):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
可以将形如 [1, 2, 3] 这类的数据转为 jittor.Var([1, 2, 3])
:param pad_val: 该值无意义
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型
:param dtype: 输出的数据的 dtype 是什么 jittor.long, jittor.float32, int, float
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
@staticmethod
def pad(batch_field, pad_val, dtype):
return jittor.Var(np.array(batch_field, dtype=dtype))
class JittorSequencePadder(Padder):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
将类似于 [[1], [1, 2]] 的内容 pad jittor.Var([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据
:param pad_val: 需要 pad 的值
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型
:param dtype: 输出的数据的 dtype 是什么 jittor.long, jittor.float32, int, float
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
@staticmethod
def pad(batch_field, pad_val, dtype):
tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val)
return tensor
class JittorTensorPadder(Padder):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
目前支持 [jittor.Var([3, 2], jittor.Var([1])] 类似的若内部元素不为 jittor.Var 则必须含有 tolist() 方法
:param pad_val: 需要 pad 的值
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型
:param dtype: 输出的数据的 dtype 是什么 jittor.long, jittor.float32, int, float
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
@staticmethod
def pad(batch_field, pad_val, dtype):
try:
if not isinstance(batch_field[0], jittor.Var):
batch_field = [jittor.Var(np.array(field.tolist(), dtype=dtype)) for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a jittor.Var (it is {type(batch_field[0])}), "
f"it must have tolist() method.")
shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
# if dtype is not None:
# tensor = jittor.full(max_shape, pad_val, dtype=dtype)
# else:
tensor = jittor.full(max_shape, pad_val, dtype=dtype)
for i, field in enumerate(batch_field):
slices = (i,) + tuple(slice(0, s) for s in shapes[i])
tensor[slices] = field
return tensor
def fill_tensor(batch_field, padded_batch, dtype):
"""
batch_field 中的值填入到 tensor
:param batch_field: 需要填充进入 array 中的内容
:param padded_batch: 待填充的 tensor
:param dtype: 数据的类别
:return:
"""
if padded_batch.ndim == 2:
for i, content_i in enumerate(batch_field):
padded_batch[i, :len(content_i)] = jittor.Var(np.array(content_i, dtype=dtype))
elif padded_batch.ndim == 3:
for i, content_i in enumerate(batch_field):
for j, content_ii in enumerate(content_i):
padded_batch[i, j, :len(content_ii)] = jittor.Var(np.array(content_ii, dtype=dtype))
elif padded_batch.ndim == 4:
try: # 应该是图像,所以直接应该就 ok 了。
padded_batch = np.array(batch_field)
except:
for i, content_i in enumerate(batch_field):
for j, content_ii in enumerate(content_i):
for k, content_iii in enumerate(content_ii):
padded_batch[i, j, k, :len(content_iii)] = jittor.Var(np.array(content_iii, dtype=dtype))
elif padded_batch.ndim == 1:
padded_batch[:] = jittor.Var(np.array(batch_field, dtype=dtype))
else:
raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
"report.")
return padded_batch
def get_padded_jittor_tensor(batch_field, dtype=None, pad_val=0):
"""
例如:
[[1,2], [3]] -> jittor.LongTensor([[1, 2], [3, 0]])
:param batch_field: 需要 pad 的对象需要保证应该是可以进行 pad 支持 1d多为句子长度/2d多为文本序列/3d多为字符序列
/4d多为图片
:param dtype: 目标类别是什么
:param pad_val: pad value
:return:
"""
shapes = get_shape(batch_field)
tensor = jittor.full(shapes, pad_val, dtype=dtype)
tensor = fill_tensor(batch_field, tensor, dtype=dtype)
return tensor

View File

@ -51,23 +51,30 @@ class Evaluator:
False那么我们会将 batch 直接透传给 forward 函数注意上述逻辑同样应用于 `train_step`, `evaluate_step` `test_step` False那么我们会将 batch 直接透传给 forward 函数注意上述逻辑同样应用于 `train_step`, `evaluate_step` `test_step`
:param fp16: 是否使用 fp16 :param fp16: 是否使用 fp16
:param verbose: 是否打印 evaluate 的结果 :param verbose: 是否打印 evaluate 的结果
:param kwargs: :param \**kwargs:
bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态 eval 状态下model 的dropout See below
batch normalization 将会关闭默认为True如果为 FalsefastNLP 不会对 model evaluate 状态做任何设置无论 :kwargs:
该值是什么fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train * *model_use_eval_mode* (``bool``) --
TODO 还没完成 是否在 evaluate 的时候将 model 的状态设置成 eval 状态 eval 状态下model
Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 dropout batch normalization 将会关闭默认为True如果为 FalsefastNLP 不会对 model evaluate 状态做任何设置无论
tensor 适配到 metrics 支持的例如 model 输出是 paddlepaddle tensor 但是想利用 torchmetrics 的metric对象 该值是什么fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train
auto_tensor_conversion_for_metric 为True时fastNLP 将自动将输出中 paddle tensor 其它非 tensor 的参数 TODO 还没完成
不做任何处理转换为 pytorch tensor 再输入到 metrics 中进行评测 model 的输出 tensor 类型通过 driver 来决定 * *auto_tensor_conversion_for_metric* (``Union[bool]``) --
metrics 支持的输入类型由 metrics 决定如果需要更复杂的转换请使用 input_mappingoutput_mapping 参数进行 是否自动将输出中的 tensor 适配到 metrics 支持的例如 model 输出是
use_dist_sampler: 是否使用分布式evaluate的方式仅当 driver 为分布式类型时该参数才有效默认为根据 driver 是否支持 paddlepaddle tensor 但是想利用 torchmetrics 的metric对象 auto_tensor_conversion_for_metric 为True时fastNLP
分布式进行设置如果为True将使得每个进程上的 dataloader 自动使用不同数据所有进程的数据并集是整个数据集 自动将输出中 paddle tensor 其它非 tensor 的参数不做任何处理转换为 pytorch tensor 再输入到 metrics 中进行评测 model
output_from_new_proc: 应当为一个字符串表示在多进程的 driver 中其它进程的输出流应当被做如何处理其值应当为以下之一 输出 tensor 类型通过 driver 来决定metrics 支持的输入类型由 metrics 决定如果需要更复杂的转换
["all", "ignore", "only_error"]当该参数的值不是以上值时该值应当表示一个文件夹的名字我们会将其他 rank 的输出流重定向到 请使用 input_mappingoutput_mapping 参数进行
log 文件中然后将 log 文件保存在通过该参数值设定的文件夹中默认为 "only_error" * *use_dist_sampler* --
progress_bar: evaluate 的时候显示的 progress bar 目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测 是否使用分布式evaluate的方式仅当 driver 为分布式类型时该参数才有效默认为根据 driver 是否支持
到当前terminal为交互型则使用 rich否则使用 raw 分布式进行设置如果为True将使得每个进程上的 dataloader 自动使用不同数据所有进程的数据并集是整个数据集
* *output_from_new_proc* --
应当为一个字符串表示在多进程的 driver 中其它进程的输出流应当被做如何处理其值应当为以下之一
["all", "ignore", "only_error"]当该参数的值不是以上值时该值应当表示一个文件夹的名字我们会将其他 rank 的输出流重定向到
log 文件中然后将 log 文件保存在通过该参数值设定的文件夹中默认为 "only_error"
* *progress_bar* --
evaluate 的时候显示的 progress bar 目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测
到当前terminal为交互型则使用 rich否则使用 raw
""" """
self.model = model self.model = model

View File

@ -67,20 +67,21 @@ class Trainer(TrainerEventTrigger):
要自己实现模型部分而将训练层面的逻辑完全地交给 fastNLP 要自己实现模型部分而将训练层面的逻辑完全地交给 fastNLP
:param model: 训练所需要的模型目前支持 pytorch :param model: 训练所需要的模型目前支持 pytorch
:param driver: 训练模型所使用的具体的驱动模式应当为以下选择中的一个["torch", "torch_ddp", ]之后我们会加入 jittorpaddle :param driver: 训练模型所使用的具体的驱动模式应当为以下选择中的一个["torch", "torch_ddp", ]之后我们会加入 jittorpaddle
国产框架的训练模式其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练 国产框架的训练模式其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练
:param train_dataloader: 训练数据集注意其必须是单独的一个数据集不能是 List 或者 Dict :param train_dataloader: 训练数据集注意其必须是单独的一个数据集不能是 List 或者 Dict
:param optimizers: 训练所需要的优化器可以是单独的一个优化器实例也可以是多个优化器组成的 List :param optimizers: 训练所需要的优化器可以是单独的一个优化器实例也可以是多个优化器组成的 List
:param device: 该参数用来指定具体训练时使用的机器注意当该参数为 None fastNLP 不会将模型和数据进行设备之间的移动处理但是你 :param device: 该参数用来指定具体训练时使用的机器注意当该参数为 None fastNLP 不会将模型和数据进行设备之间的移动处理但是你
可以通过参数 `input_mapping` `output_mapping` 来实现设备之间数据迁移的工作通过这两个参数传入两个处理数据的函数同时你也 可以通过参数 `input_mapping` `output_mapping` 来实现设备之间数据迁移的工作通过这两个参数传入两个处理数据的函数同时你也
可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上注意这种情况理应只出现在用户在 Trainer 实例化前 可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上注意这种情况理应只出现在用户在 Trainer 实例化前
自己构造 DDP 的多进程场景 自己构造 DDP 的多进程场景
device 的可选输入如下所示 device 的可选输入如下所示
1. 可选输入str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu', 可见的第一个GPU中, 可见的第一个GPU中, 可见的第二个GPU中 1. 可选输入str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu', 可见的第一个GPU中, 可见的第一个GPU中, 可见的第二个GPU中
2. torch.device将模型装载到torch.device上 2. torch.device将模型装载到torch.device上
3. int 将使用device_id为该值的gpu进行训练如果值为 -1那么默认使用全部的显卡此时是 `TorchDDPDriver` 3. int 将使用device_id为该值的gpu进行训练如果值为 -1那么默认使用全部的显卡此时是 `TorchDDPDriver`
4. list(int)如果多于1个device应当通过该种方式进行设定 `device` 为一个 list 我们默认使用 `TorchDDPDriver` 4. list(int)如果多于1个device应当通过该种方式进行设定 `device` 为一个 list 我们默认使用 `TorchDDPDriver`
5. None 为None则不对模型进行任何处理 5. None 为None则不对模型进行任何处理
:param n_epochs: 训练总共的 epoch 的数量默认为 20 :param n_epochs: 训练总共的 epoch 的数量默认为 20
:param evaluate_dataloaders: 验证数据集其可以是单独的一个数据集也可以是多个数据集当为多个数据集时注意其必须是 Dict默认 :param evaluate_dataloaders: 验证数据集其可以是单独的一个数据集也可以是多个数据集当为多个数据集时注意其必须是 Dict默认
None None
@ -121,26 +122,27 @@ class Trainer(TrainerEventTrigger):
如果 evaluate_dataloaders metrics 没有提供该参数无意义 如果 evaluate_dataloaders metrics 没有提供该参数无意义
:param larger_better: monitor 的值是否是越大越好 :param larger_better: monitor 的值是否是越大越好
:param marker: 用于标记一个 Trainer 实例从而在用户调用 `Trainer.on` 函数时标记该 callback 函数属于哪一个具体的 'trainer' 实例默认为 None :param marker: 用于标记一个 Trainer 实例从而在用户调用 `Trainer.on` 函数时标记该 callback 函数属于哪一个具体的 'trainer' 实例默认为 None
:param kwargs: 一些其它的可能需要的参数 :param kwargs: 一些其它的可能需要的参数见下方的说明
torch_non_blocking: 表示用于 pytorch tensor to 方法的参数 non_blocking :kwargs:
data_device: 表示如果用户的模型 device Driver 中对应为参数 model_device None 我们会将数据迁移到 data_device * *torch_non_blocking* -- 表示用于 pytorch tensor to 方法的参数 non_blocking
注意如果 model_device None那么 data_device 不会起作用 * *data_device* -- 表示如果用户的模型 device Driver 中对应为参数 model_device None 我们会将数据迁移到 data_device
torch_ddp_kwargs: 用于配置 pytorch DistributedDataParallel 初始化时的参数仅用于 pytorch ddp 训练例如传入 注意如果 model_device None那么 data_device 不会起作用
{'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等 * *torch_ddp_kwargs* -- 用于配置 pytorch DistributedDataParallel 初始化时的参数仅用于 pytorch ddp 训练例如传入
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None {'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等
use_dist_sampler: 表示是否使用分布式的 sampler 在多卡时分布式 sampler 将自动决定每张卡上读取的 sample 使得一个epoch * *set_grad_to_none* -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None
* *use_dist_sampler* -- 表示是否使用分布式的 sampler 在多卡时分布式 sampler 将自动决定每张卡上读取的 sample 使得一个epoch
内所有卡的 sample 加起来为一整个数据集的 sample默认会根据 driver 是否为分布式进行设置 内所有卡的 sample 加起来为一整个数据集的 sample默认会根据 driver 是否为分布式进行设置
evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader sampler 替换为分布式的 sampler默认为 True * *evaluate_use_dist_sampler* -- 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader sampler 替换为分布式的 sampler默认为 True
output_from_new_proc: 应当为一个字符串表示在多进程的 driver 中其它进程的输出流应当被做如何处理其值应当为以下之一 * *output_from_new_proc* -- 应当为一个字符串表示在多进程的 driver 中其它进程的输出流应当被做如何处理其值应当为以下之一
["all", "ignore", "only_error"]当该参数的值不是以上值时该值应当表示一个文件夹的名字我们会将其他 rank 的输出流重定向到 ["all", "ignore", "only_error"]当该参数的值不是以上值时该值应当表示一个文件夹的名字我们会将其他 rank 的输出流重定向到
log 文件中然后将 log 文件保存在通过该参数值设定的文件夹中默认为 "only_error" log 文件中然后将 log 文件保存在通过该参数值设定的文件夹中默认为 "only_error"
progress_bar: 以哪种方式显示 progress 目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象 * *progress_bar* -- 以哪种方式显示 progress 目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback否则使用 RawTextCallback对象如果 默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback否则使用 RawTextCallback对象如果
需要定制 progress bar 的参数例如打印频率等可以传入 RichCallback, RawTextCallback 对象 需要定制 progress bar 的参数例如打印频率等可以传入 RichCallback, RawTextCallback 对象
train_input_mapping: input_mapping 一致但是只用于 train input_mapping 互斥 * *train_input_mapping* -- input_mapping 一致但是只用于 train input_mapping 互斥
train_output_mapping: output_mapping 一致但是只用于 train output_mapping 互斥 * *train_output_mapping* -- output_mapping 一致但是只用于 train output_mapping 互斥
evaluate_input_mapping: input_mapping 一致但是只用于 evaluate input_mapping 互斥 * *evaluate_input_mapping* -- input_mapping 一致但是只用于 evaluate input_mapping 互斥
evaluate_output_mapping: output_mapping 一致但是只用于 evaluate output_mapping 互斥 * *evaluate_output_mapping* -- output_mapping 一致但是只用于 evaluate output_mapping 互斥
""" """
self.model = model self.model = model
self.marker = marker self.marker = marker
@ -290,14 +292,14 @@ class Trainer(TrainerEventTrigger):
catch_KeyboardInterrupt=None): catch_KeyboardInterrupt=None):
""" """
注意如果是断点重训的第一次训练即还没有保存任何用于断点重训的文件那么其应当置 resume_from None并且使用 ModelCheckpoint 注意如果是断点重训的第一次训练即还没有保存任何用于断点重训的文件那么其应当置 resume_from None并且使用 ModelCheckpoint
去保存断点重训的文件 去保存断点重训的文件
:param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止-1 为根据 dataloader 有多少个 batch 决定 :param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止-1 为根据 dataloader 有多少个 batch 决定
:param num_eval_batch_per_dl: 每个 evaluate dataloader 运行多少个 batch 停止-1 为根据 dataloader 有多少个 batch 决定 :param num_eval_batch_per_dl: 每个 evaluate dataloader 运行多少个 batch 停止-1 为根据 dataloader 有多少个 batch 决定
:param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 是否有错误 0 表示不检测 :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 是否有错误 0 表示不检测
:param resume_from: 从哪个路径下恢复 trainer 的状态 :param resume_from: 从哪个路径下恢复 trainer 的状态
:param resume_training: 是否按照 checkpoint 中训练状态恢复如果为 False则只恢复 model optimizers 的状态 :param resume_training: 是否按照 checkpoint 中训练状态恢复如果为 False则只恢复 model optimizers 的状态
:param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话不会抛出一场trainer.run()之后的代码会继续运 :param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话不会抛出一场trainer.run()之后的代码会继续运
默认如果非 distributed driver catch distributed 不会 catch 无法 catch 默认如果非 distributed driver catch distributed 不会 catch 无法 catch
:return: :return:
""" """
@ -417,39 +419,42 @@ class Trainer(TrainerEventTrigger):
def on(cls, event: Event, marker: Optional[str] = None): def on(cls, event: Event, marker: Optional[str] = None):
r""" r"""
函数修饰器用户可以使用该函数来方便地将一个函数转变为 callback 函数从而进行训练流程中的控制 函数修饰器用户可以使用该函数来方便地将一个函数转变为 callback 函数从而进行训练流程中的控制
支持的 event 时机有以下这些其执行的时机顺序也如下所示每个时机装饰的函数应该接受的参数列表也如下所示例如 支持的 event 时机有以下这些其执行的时机顺序也如下所示每个时机装饰的函数应该接受的参数列表也如下所示例如::
Trainer.__init__():
on_after_trainer_initialized(trainer, driver) Trainer.__init__():
Trainer.run(): on_after_trainer_initialized(trainer, driver)
if num_eval_sanity_batch>0: Trainer.run():
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch if num_eval_sanity_batch>0:
on_sanity_check_end(trainer, sanity_check_res) on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch
try: on_sanity_check_end(trainer, sanity_check_res)
on_train_begin(trainer) try:
while cur_epoch_idx < n_epochs: on_train_begin(trainer)
on_train_epoch_begin(trainer) while cur_epoch_idx < n_epochs:
while batch_idx_in_epoch<=num_batches_per_epoch: on_train_epoch_begin(trainer)
on_fetch_data_begin(trainer) while batch_idx_in_epoch<=num_batches_per_epoch:
batch = next(dataloader) on_fetch_data_begin(trainer)
on_fetch_data_end(trainer) batch = next(dataloader)
on_train_batch_begin(trainer, batch, indices) on_fetch_data_end(trainer)
on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping如果设置了 后的,否则即为 model 的输出。 on_train_batch_begin(trainer, batch, indices)
on_after_backward(trainer) on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping如果设置了 后的,否则即为 model 的输出。
on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 on_after_backward(trainer)
on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_train_batch_end(trainer) on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_train_epoch_end(trainer) on_train_batch_end(trainer)
except BaseException: on_train_epoch_end(trainer)
self.on_exception(trainer, exception) except BaseException:
finally: self.on_exception(trainer, exception)
on_train_end(trainer) finally:
on_train_end(trainer)
其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ 其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run() on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()
特定的时间调用 特定的时间调用
Example:: Example::
from fastNLP import Event from fastNLP import Event
@Trainer.on(Event.on_save_model()) @Trainer.on(Event.on_save_model())
def do_something_1(trainer): def do_something_1(trainer):
@ -696,7 +701,7 @@ class Trainer(TrainerEventTrigger):
r""" r"""
用于断点重训的加载函数 用于断点重训的加载函数
注意在 fastNLP 中断点重训的保存和加载逻辑是分开的因此可能存在一种情况用户只希望加载一个断点重训的状态而在之后不再进行断点重训的 注意在 fastNLP 中断点重训的保存和加载逻辑是分开的因此可能存在一种情况用户只希望加载一个断点重训的状态而在之后不再进行断点重训的
保存在这种情况下dataloader sampler 就不一定会被替换成我们的 ReproducibleSampler 保存在这种情况下dataloader sampler 就不一定会被替换成我们的 ReproducibleSampler
注意我们目前不支持单卡到多卡的断点重训 注意我们目前不支持单卡到多卡的断点重训

View File

@ -26,7 +26,8 @@ class State(dict):
为了实现断点重训用户应当保证其保存的信息都是可序列化的 为了实现断点重训用户应当保证其保存的信息都是可序列化的
推荐的使用方式 推荐的使用方式::
>>> state = State() >>> state = State()
>>> state["best_accuracy"] = 0.9 >>> state["best_accuracy"] = 0.9
>>> print(state["best_accuracy"]) >>> print(state["best_accuracy"])

View File

@ -64,38 +64,40 @@ class JittorDataLoader:
:param collate_fn: 对取得到的数据进行打包的callable函数 :param collate_fn: 对取得到的数据进行打包的callable函数
:param as_numpy: 返回数据是否设置为numpy类型否则为torch.tensor类型 :param as_numpy: 返回数据是否设置为numpy类型否则为torch.tensor类型
""" """
# TODO 支持fastnlp dataset
# TODO 验证支持replacesampler (以后完成) # TODO 验证支持replacesampler (以后完成)
# 是否为 jittor 类型的 dataset # FastNLP Datset, collate_fn not None
if isinstance(dataset, FDataSet) and collate_fn is None:
raise ValueError("When use FastNLP DataSet, collate_fn must be not None")
if not isinstance(dataset, _JittorDataset):
self.dataset = _JittorDataset(dataset)
if isinstance(collate_fn, str): if isinstance(collate_fn, str):
if collate_fn == "auto": if collate_fn == "auto":
if isinstance(dataset, FDataSet): if isinstance(self.dataset.dataset, FDataSet):
self._collate_fn = dataset.collator self.collate_fn = self.dataset.dataset.collator
self._collate_fn.set_backend(backend="jittor") self.collate_fn.set_backend(backend="jittor")
else: else:
self._collate_fn = Collator(backend="jittor") self.collate_fn = Collator(backend="jittor")
else: else:
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
elif isinstance(collate_fn, Callable): elif isinstance(collate_fn, Callable):
if collate_fn is not collate_batch: if collate_fn is not collate_batch:
self._collate_fn = collate_fn self.collate_fn = collate_fn
else: else:
self._collate_fn = collate_batch self.collate_fn = collate_batch
self.dataset = _JittorDataset(dataset)
self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad,
keep_numpy_array=keep_numpy_array, endless=endless) keep_numpy_array=keep_numpy_array, endless=endless)
# 将内部dataset批次设置为1
if isinstance(self.dataset.dataset, Dataset): if isinstance(self.dataset.dataset, Dataset):
self.dataset.dataset.set_attrs(batch_size=1) self.dataset.dataset.set_attrs(batch_size=1)
# 用户提供了 collate_fn则会自动代替 jittor 提供 collate_batch 函数
# self._collate_fn = _collate_fn
self.cur_batch_indices = None self.cur_batch_indices = None
def __iter__(self): def __iter__(self):
# TODO 第一次迭代后不能设置collate_fn设置是无效的 # TODO 第一次迭代后不能设置collate_fn设置是无效的
self.collate_fn = self._collate_fn
if self.cur_batch_indices is None: if self.cur_batch_indices is None:
self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn)) self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn))
for indices, data in self.dataset.__iter__(): for indices, data in self.dataset.__iter__():
@ -107,8 +109,8 @@ class JittorDataLoader:
return len(self.dataset) // self.dataset.batch_size return len(self.dataset) // self.dataset.batch_size
return (len(self.dataset) - 1) // self.dataset.batch_size + 1 return (len(self.dataset) - 1) // self.dataset.batch_size + 1
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator: pad_fn: Callable = None) -> "JittorDataLoader":
""" """
如果需要对某个 field 的内容进行特殊的调整请使用这个函数 如果需要对某个 field 的内容进行特殊的调整请使用这个函数
@ -127,16 +129,18 @@ class JittorDataLoader:
形式输出将被直接作为结果输出 形式输出将被直接作为结果输出
:return: 返回 Collator 自身 :return: 返回 Collator 自身
""" """
if isinstance(self._collate_fn, Collator): if isinstance(self.collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) self.collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
return self._collate_fn backend=backend)
return self
else: else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
def set_ignore(self, *field_names) -> Collator: def set_ignore(self, *field_names) -> "JittorDataLoader":
""" """
如果有的内容不希望输出可以在此处进行设置被设置的 field 将在 batch 的输出中被忽略 如果有的内容不希望输出可以在此处进行设置被设置的 field 将在 batch 的输出中被忽略
Example:: Example::
collator.set_ignore('field1', 'field2') collator.set_ignore('field1', 'field2')
:param field_names: 需要忽略的 field 的名称如果 Dataset __getitem__ 方法返回的是 dict 类型的则可以直接使用对应的 :param field_names: 需要忽略的 field 的名称如果 Dataset __getitem__ 方法返回的是 dict 类型的则可以直接使用对应的
@ -144,9 +148,9 @@ class JittorDataLoader:
__getitem__ 返回的是 Sequence 类型的则可以使用 '_0', '_1' 表示序列中第 0 1 个元素 __getitem__ 返回的是 Sequence 类型的则可以使用 '_0', '_1' 表示序列中第 0 1 个元素
:return: 返回 Collator 自身 :return: 返回 Collator 自身
""" """
if isinstance(self._collate_fn, Collator): if isinstance(self.collate_fn, Collator):
self._collate_fn.set_ignore(*field_names) self.collate_fn.set_ignore(*field_names)
return self._collate_fn return self
else: else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
@ -158,5 +162,6 @@ class JittorDataLoader:
""" """
return self.cur_batch_indices return self.cur_batch_indices
def prepare_jittor_dataloader(): def prepare_jittor_dataloader():
... ...

View File

@ -9,7 +9,6 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE: if _NEED_IMPORT_PADDLE:
from paddle.io import DataLoader, Dataset, Sampler from paddle.io import DataLoader, Dataset, Sampler
from paddle.fluid.dataloader.collate import default_collate_fn
else: else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset from fastNLP.core.utils.dummy_class import DummyClass as Dataset
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
@ -52,6 +51,9 @@ class PaddleDataLoader(DataLoader):
num_workers: int = 0, use_buffer_reader: bool = True, num_workers: int = 0, use_buffer_reader: bool = True,
use_shared_memory: bool = True, timeout: int = 0, use_shared_memory: bool = True, timeout: int = 0,
worker_init_fn: Callable = None, persistent_workers=False) -> None: worker_init_fn: Callable = None, persistent_workers=False) -> None:
# FastNLP Datset, collate_fn not None
if isinstance(dataset, FDataSet) and collate_fn is None:
raise ValueError("When use FastNLP DataSet, collate_fn must be not None")
if not isinstance(dataset, _PaddleDataset): if not isinstance(dataset, _PaddleDataset):
dataset = _PaddleDataset(dataset) dataset = _PaddleDataset(dataset)
@ -66,10 +68,10 @@ class PaddleDataLoader(DataLoader):
if isinstance(collate_fn, str): if isinstance(collate_fn, str):
if collate_fn == 'auto': if collate_fn == 'auto':
if isinstance(dataset.dataset, FDataSet): if isinstance(dataset.dataset, FDataSet):
self._collate_fn = dataset.dataset.collator collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="paddle") collate_fn.set_backend(backend="paddle")
else: else:
self._collate_fn = Collator(backend="paddle") collate_fn = Collator(backend="paddle")
else: else:
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
@ -142,6 +144,7 @@ class PaddleDataLoader(DataLoader):
""" """
如果有的内容不希望输出可以在此处进行设置被设置的 field 将在 batch 的输出中被忽略 如果有的内容不希望输出可以在此处进行设置被设置的 field 将在 batch 的输出中被忽略
Example:: Example::
collator.set_ignore('field1', 'field2') collator.set_ignore('field1', 'field2')
:param field_names: 需要忽略的 field 的名称如果 Dataset __getitem__ 方法返回的是 dict 类型的则可以直接使用对应的 :param field_names: 需要忽略的 field 的名称如果 Dataset __getitem__ 方法返回的是 dict 类型的则可以直接使用对应的
@ -187,7 +190,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
dl_bundle = {} dl_bundle = {}
for name, ds in ds_or_db.iter_datasets(): for name, ds in ds_or_db.iter_datasets():
if 'train' in name: if 'train' in name:
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
return_list=return_list, return_list=return_list,
batch_sampler=batch_sampler, batch_size=train_batch_size, batch_sampler=batch_sampler, batch_size=train_batch_size,
shuffle=shuffle, shuffle=shuffle,
@ -197,7 +200,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
timeout=timeout, worker_init_fn=worker_init_fn, timeout=timeout, worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers) persistent_workers=persistent_workers)
else: else:
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
return_list=return_list, return_list=return_list,
batch_sampler=batch_sampler, batch_size=non_train_batch_size, batch_sampler=batch_sampler, batch_size=non_train_batch_size,
shuffle=shuffle, shuffle=shuffle,

View File

@ -153,6 +153,7 @@ class TorchDataLoader(DataLoader):
""" """
如果有的内容不希望输出可以在此处进行设置被设置的 field 将在 batch 的输出中被忽略 如果有的内容不希望输出可以在此处进行设置被设置的 field 将在 batch 的输出中被忽略
Example:: Example::
collator.set_ignore('field1', 'field2') collator.set_ignore('field1', 'field2')
:param field_names: 需要忽略的 field 的名称如果 Dataset __getitem__ 方法返回的是 dict 类型的则可以直接使用对应的 :param field_names: 需要忽略的 field 的名称如果 Dataset __getitem__ 方法返回的是 dict 类型的则可以直接使用对应的

View File

@ -706,8 +706,8 @@ class DataSet:
def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet':
""" """
将当前dataset与输入的dataset结合成一个更大的dataset需要保证两个dataset都包含了相同的field结合后的dataset的input,target 将当前dataset与输入的dataset结合成一个更大的dataset需要保证两个dataset都包含了相同的field结合后的dataset的input,target
以及collate_fn以当前dataset为准当dataset中包含的field多于当前的dataset则多余的field会被忽略若dataset中未包含所有 以及collate_fn以当前dataset为准当dataset中包含的field多于当前的dataset则多余的field会被忽略若dataset中未包含所有
当前dataset含有field则会报错 当前dataset含有field则会报错
:param DataSet, dataset: 需要和当前dataset concat的dataset :param DataSet, dataset: 需要和当前dataset concat的dataset
:param bool, inplace: 是否直接将dataset组合到当前dataset中 :param bool, inplace: 是否直接将dataset组合到当前dataset中

View File

@ -87,8 +87,8 @@ class Driver(ABC):
:param batch: 当前的一个 batch 的数据可以为字典或者其它类型 :param batch: 当前的一个 batch 的数据可以为字典或者其它类型
:param fn: 调用该函数进行一次计算 :param fn: 调用该函数进行一次计算
:param signature_fn: Trainer 传入的用于网络前向传播一次的签名函数因为当 batch 是一个 Dict 的时候我们会自动调用 auto_param_call :param signature_fn: Trainer 传入的用于网络前向传播一次的签名函数因为当 batch 是一个 Dict 的时候我们会自动调用 auto_param_call
而一些被包裹的模型需要暴露其真正的函数签名例如 DistributedDataParallel 的调用函数是 forward但是需要其函数签名为 model.module.forward 而一些被包裹的模型需要暴露其真正的函数签名例如 DistributedDataParallel 的调用函数是 forward但是需要其函数签名为 model.module.forward
:return: 返回由 `fn` 返回的结果应当为一个 dict 或者 dataclass但是不需要我们去检查 :return: 返回由 `fn` 返回的结果应当为一个 dict 或者 dataclass但是不需要我们去检查
""" """
raise NotImplementedError("Each specific driver should implemented its own `model_call` function.") raise NotImplementedError("Each specific driver should implemented its own `model_call` function.")
@ -106,9 +106,10 @@ class Driver(ABC):
`evaluate step fn` 的确定却需要 Evaluator 的初始化因此我们将这一逻辑抽象到这一函数当中 `evaluate step fn` 的确定却需要 Evaluator 的初始化因此我们将这一逻辑抽象到这一函数当中
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数具体逻辑如下所示 这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数具体逻辑如下所示
1. 如果 fn == "train_step" or "evaluate_step"那么对传入的模型进行检测如果模型没有定义方法 `fn`则默认调用模型的 `forward` 1. 如果 fn == "train_step" or "evaluate_step"那么对传入的模型进行检测如果模型没有定义方法 `fn`则默认调用模型的 `forward`
函数然后给出 warning 函数然后给出 warning
2. 如果 fn 是其他字符串那么如果模型没有定义方法 `fn` 则直接报错 2. 如果 fn 是其他字符串那么如果模型没有定义方法 `fn` 则直接报错
注意不同的 driver 需要做额外的检测处理例如在 DDPDriver 当传入的模型本身就是 DistributedDataParallel 我们只能调用模型的 注意不同的 driver 需要做额外的检测处理例如在 DDPDriver 当传入的模型本身就是 DistributedDataParallel 我们只能调用模型的
forward 函数因此需要额外的 warning这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变DDPDriver因此 forward 函数因此需要额外的 warning这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变DDPDriver因此
可能需要额外标记最初传入 driver 的模型是哪种形式的 可能需要额外标记最初传入 driver 的模型是哪种形式的
@ -376,7 +377,7 @@ class Driver(ABC):
pid 记录下来然后在出现错误后由出现错误的进程手动地将其它进程 kill pid 记录下来然后在出现错误后由出现错误的进程手动地将其它进程 kill
因此每一个多进程 driver 如果想要该函数能够正确地执行其需要在自己的 open_subprocess开启多进程的函数中正确地记录每一个进程的 因此每一个多进程 driver 如果想要该函数能够正确地执行其需要在自己的 open_subprocess开启多进程的函数中正确地记录每一个进程的
pid 的信息 pid 的信息
""" """
# 单卡 driver 不需要这个函数; # 单卡 driver 不需要这个函数;
if self._pids is not None: if self._pids is not None:

View File

@ -33,11 +33,12 @@ class JittorDriver(Driver):
f"`jittor.Module` type.") f"`jittor.Module` type.")
super(JittorDriver, self).__init__(model) super(JittorDriver, self).__init__(model)
self.model = model
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
self.grad_scaler = _grad_scaler() self.grad_scaler = _grad_scaler()
# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
@staticmethod @staticmethod
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
# 在fastnlp中实现了JittorDataLoader # 在fastnlp中实现了JittorDataLoader
@ -152,4 +153,4 @@ class JittorDriver(Driver):
# def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx): # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx):
# # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; # # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的;
# if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): # if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
# dataloader.batch_sampler.set_epoch(cur_epoch_idx) # dataloader.batch_sampler.set_epoch(cur_epoch_idx)

View File

@ -60,8 +60,8 @@ class JittorSingleDriver(JittorDriver):
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
return fn, None return fn, None
elif fn in {"train_step", "evaluate_step"}: elif fn in {"train_step", "evaluate_step"}:
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...') logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
return self.model, self.model.forward return self.model, self.model.execute
else: else:
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
@ -98,3 +98,9 @@ class JittorSingleDriver(JittorDriver):
return dataloader return dataloader
else: else:
return dataloader return dataloader
def setup(self):
"""
使用单个 GPU jittor 底层自动实现调配无需额外操作
"""
pass

View File

@ -172,6 +172,7 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=None) ->List:
实现任何类型的数据都使用该接口可以进行 all_gather 操作对于非 tensor 类型的数据通过 pickle 序列化再反序列化的方式进行传输 实现任何类型的数据都使用该接口可以进行 all_gather 操作对于非 tensor 类型的数据通过 pickle 序列化再反序列化的方式进行传输
example:: example::
obj = { obj = {
'a': [1, 1], 'a': [1, 1],
'b': [[1, 2], [1, 2]], 'b': [[1, 2], [1, 2]],

View File

@ -534,7 +534,7 @@ class TorchDDPDriver(TorchDriver):
def broadcast_object(self, obj, src:int=0, group=None, **kwargs): def broadcast_object(self, obj, src:int=0, group=None, **kwargs):
""" """
src 端将 obj 对象可能是 tensor 可能是 object 发送到 dst 如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 src 端将 obj 对象可能是 tensor 可能是 object 发送到 dst 如果是非 tensor 的对象会尝试使用 pickle 进行打包进行
传输然后再 dst 处再加载回来仅在分布式的 driver 中有实际意义 传输然后再 dst 处再加载回来仅在分布式的 driver 中有实际意义
:param obj: obj可能是 Tensor 嵌套类型的数据 :param obj: obj可能是 Tensor 嵌套类型的数据
:param int src: source global rank :param int src: source global rank
@ -551,9 +551,10 @@ class TorchDDPDriver(TorchDriver):
def all_gather(self, obj, group) -> List: def all_gather(self, obj, group) -> List:
""" """
obj 互相传送到其它所有的 rank 其中 obj 可能是 Tensor也可能是嵌套结构的 object 如果不是基础类型的数据尝试通过 obj 互相传送到其它所有的 rank 其中 obj 可能是 Tensor也可能是嵌套结构的 object 如果不是基础类型的数据尝试通过
pickle 进行序列化接收到之后再反序列化 pickle 进行序列化接收到之后再反序列化
example::
example:
obj = { obj = {
'a': [1, 1], 'a': [1, 1],
'b': [[1, 2], [1, 2]], 'b': [[1, 2], [1, 2]],

View File

@ -175,7 +175,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) -
""" """
实现任何类型的数据都使用该接口可以进行 all_gather 操作对于非 tensor 类型的数据通过 pickle 序列化再反序列化的方式进行传输 实现任何类型的数据都使用该接口可以进行 all_gather 操作对于非 tensor 类型的数据通过 pickle 序列化再反序列化的方式进行传输
example: example::
obj = { obj = {
'a': [1, 1], 'a': [1, 1],
'b': [[1, 2], [1, 2]], 'b': [[1, 2], [1, 2]],

View File

@ -175,16 +175,18 @@ def _build_fp16_env(dummy=False):
def replace_sampler(dataloader: "DataLoader", sampler): def replace_sampler(dataloader: "DataLoader", sampler):
""" """
替换 sampler 初始化一个新的 dataloader 的逻辑在于 替换 sampler 初始化一个新的 dataloader 的逻辑在于
用户可能继承了 dataloader定制了自己的 dataloader 这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 用户可能继承了 dataloader定制了自己的 dataloader 这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接
`inspect.signature(DataLoader)` 的原因因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader `inspect.signature(DataLoader)` 的原因因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader
的类而不是直接的 DataLoader 的类而不是直接的 DataLoader
如果需要定制自己的 dataloader保证以下两点
1. __init__ 方法中加入 **kwargs这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中
2. __init__ 方法中出现的参数请务必挂为同样名字的实例属性例如 self.one_arg_name = one_arg_name这是因为我们只能通过属性
来获取实际的参数的值
如果需要定制自己的 dataloader保证以下两点
1. __init__ 方法中加入 **kwargs这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中
2. __init__ 方法中出现的参数请务必挂为同样名字的实例属性例如 self.one_arg_name = one_arg_name这是因为我们只能通过属性
来获取实际的参数的值
""" """
# 拿到实例属性; # 拿到实例属性;

View File

@ -1,18 +1,20 @@
r""" r"""
Logger 是fastNLP中记录日志的模块logger封装了logging模块的Logger Logger 是fastNLP中记录日志的模块logger封装了logging模块的Logger
具体使用方式与直接使用logging.Logger相同同时也新增一些简单好用的API 具体使用方式与直接使用logging.Logger相同同时也新增一些简单好用的API
使用方式
from fastNLP import _logger
#
# _logger 可以和 logging.Logger 一样使用
_logger.info('your msg')
_logger.error('your msg')
# _logger 新增的API 使用方式::
# 将日志输出到文件,以及输出的日志等级
_logger.add_file('/path/to/log', level='INFO') from fastNLP import _logger
# 定义在命令行中的显示格式和日志等级 #
_logger.set_stdout('tqdm', level='WARN') # _logger 可以和 logging.Logger 一样使用
_logger.info('your msg')
_logger.error('your msg')
# _logger 新增的API
# 将日志输出到文件,以及输出的日志等级
_logger.add_file('/path/to/log', level='INFO')
# 定义在命令行中的显示格式和日志等级
_logger.set_stdout('tqdm', level='WARN')
""" """

View File

@ -10,12 +10,13 @@ def print(*args, sep=' ', end='\n', file=None, flush=False):
用来重定向 print 函数至 logger.info 的函数 用来重定向 print 函数至 logger.info 的函数
Example:: Example::
from fastNLP import print from fastNLP import print
print("This is a test") # 等价于调用了 logger.info("This is a test") print("This is a test") # 等价于调用了 logger.info("This is a test")
:param args: 需要打印的内容 :param args: 需要打印的内容
:param sep: 存在多个输入时使用的间隔 :param sep: 存在多个输入时使用的间隔
:param end: 该参数在当前设置无意义因为结尾一定会被加入 \n :param end: 该参数在当前设置无意义因为结尾一定会被加入 '\\\\n'
:param file: 该参数无意义 :param file: 该参数无意义
:param flush: 该参数无意义 :param flush: 该参数无意义
:return: :return:

View File

@ -38,7 +38,7 @@ class Metric:
def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element: def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element:
""" """
注册一个 element 对象注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用可以认为该对象即为对应 backend 注册一个 element 对象注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用可以认为该对象即为对应 backend
tensor 直接进行加减乘除计算即可 tensor 直接进行加减乘除计算即可
注意如果想使得该 metric 可自动扩展到多卡的情况请一定申明 aggregate_method 注意如果想使得该 metric 可自动扩展到多卡的情况请一定申明 aggregate_method
:param name: 当前 element 的名字注册后 Metric 中可以通过 self.{name} 访问该变量 :param name: 当前 element 的名字注册后 Metric 中可以通过 self.{name} 访问该变量
@ -48,7 +48,7 @@ class Metric:
Torch.tensor 如果backend paddle 则该对象为 paddle.tensor 如果 backend jittor , 则该对象为 jittor.Var Torch.tensor 如果backend paddle 则该对象为 paddle.tensor 如果 backend jittor , 则该对象为 jittor.Var
一般情况下直接默认为 auto 就行了fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化例如当传入 一般情况下直接默认为 auto 就行了fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化例如当传入
的参数中只包含 torch.Tensor 这一种 tensor 可以有其它非 tensor 类型的输入则认为 backend torch 只包含 的参数中只包含 torch.Tensor 这一种 tensor 可以有其它非 tensor 类型的输入则认为 backend torch 只包含
jittor.Var 则认为 backend 这一种 tensor 可以有其它非 tensor 类型的输入则认为 backend jittor 如果没有检测 jittor.Var 则认为 backend 这一种 tensor 可以有其它非 tensor 类型的输入则认为 backend jittor 如果没有检测
到任何一种 tensor 就默认使用 float 类型作为 element 到任何一种 tensor 就默认使用 float 类型作为 element
:return: 注册的 Element 对象 :return: 注册的 Element 对象
""" """

View File

@ -496,7 +496,7 @@ class PollingSampler(MixSampler):
:param sampler: 实例化好的sampler每个dataset对应一个sampler对象 :param sampler: 实例化好的sampler每个dataset对应一个sampler对象
:param drop_last: 是否去掉最后一个batch的数据其长度小于batch_size :param drop_last: 是否去掉最后一个batch的数据其长度小于batch_size
:param ds_ratio: 当ds_ratio=None时候 轮流采样dataset列表直至所有的数据集采样完当ds_ratio='truncate_to_least' :param ds_ratio: 当ds_ratio=None时候 轮流采样dataset列表直至所有的数据集采样完当ds_ratio='truncate_to_least'
以dataset列表最短的ds为基准长的数据集会被截断当ds_ratio='pad_to_most'以dataset列表最长ds为基准短的数据集会被重采样 以dataset列表最短的ds为基准长的数据集会被截断当ds_ratio='pad_to_most'以dataset列表最长ds为基准短的数据集会被重采样
""" """
super(PollingSampler, self).__init__(dataset=dataset, batch_size=batch_size, super(PollingSampler, self).__init__(dataset=dataset, batch_size=batch_size,
sampler=sampler, ds_ratio=ds_ratio, sampler=sampler, ds_ratio=ds_ratio,

View File

@ -35,7 +35,9 @@ class NumConsumedSamplesArray:
def __init__(self, buffer_size=2000, num_consumed_samples=0): def __init__(self, buffer_size=2000, num_consumed_samples=0):
""" """
保留 buffer_size num_consumed_samples 数据可以索引得到某个 index 下的 num_consumed_samples 多少 保留 buffer_size num_consumed_samples 数据可以索引得到某个 index 下的 num_consumed_samples 多少
Example:: Example::
array = NumConsumedSamplesArray(buffer_size=3) array = NumConsumedSamplesArray(buffer_size=3)
for i in range(10): for i in range(10):
array.push(i) array.push(i)

View File

@ -222,7 +222,7 @@ def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _chec
可以看到第二次运行的时候只用了0.0001s左右是由于第二次运行将直接从cache.pkl这个文件读取数据而不会经过再次预处理 可以看到第二次运行的时候只用了0.0001s左右是由于第二次运行将直接从cache.pkl这个文件读取数据而不会经过再次预处理
如果在函数加上了装饰器@cache_results()则函数会增加五个参数[_cache_fp, _hash_param, _refresh, _verbose, 如果在函数加上了装饰器@cache_results()则函数会增加五个参数[_cache_fp, _hash_param, _refresh, _verbose,
_check_hash]上面的例子即为使用_cache_fp的情况这五个参数不会传入到被装饰函数中当然被装饰函数参数名也不能包含这五个名称:: _check_hash]上面的例子即为使用_cache_fp的情况这五个参数不会传入到被装饰函数中当然被装饰函数参数名也不能包含这五个名称
:param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存如果为Nonecache_results没有任何效用除非在 :param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存如果为Nonecache_results没有任何效用除非在
函数调用的时候传入 _cache_fp 这个参数保存文件的名称会受到 函数调用的时候传入 _cache_fp 这个参数保存文件的名称会受到

View File

@ -257,12 +257,13 @@ def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None,
对于 `output_mapping`该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用 对于 `output_mapping`该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用
转换的逻辑按优先级依次为 转换的逻辑按优先级依次为
1. 如果 `mapping` 是一个函数那么会直接返回 `mapping(data)`
2. 如果 `mapping` 是一个 `Dict`那么 `data` 的类型只能为以下三种 [`Dict`, `dataclass`, `Sequence`] 1. 如果 `mapping` 是一个函数那么会直接返回 `mapping(data)`
如果 `data` `Dict`那么该函数会将 `data` key 替换为 mapping[key] 2. 如果 `mapping` 是一个 `Dict`那么 `data` 的类型只能为以下三种 [`Dict`, `dataclass`, `Sequence`]
如果 `data` `dataclass`那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`然后进行转换 如果 `data` `Dict`那么该函数会将 `data` key 替换为 mapping[key]
如果 `data` `Sequence`那么该函数会先将其转换成一个对应的 `Dict`{"_0": list[0], "_1": list[1], ...}然后使用 如果 `data` `dataclass`那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`然后进行转换
mapping对这个 `Dict` 进行转换如果没有匹配上mapping中的key则保持"_number"这个形式 如果 `data` `Sequence`那么该函数会先将其转换成一个对应的 `Dict`{"_0": list[0], "_1": list[1], ...}然后使用
mapping对这个 `Dict` 进行转换如果没有匹配上mapping中的key则保持"_number"这个形式
:param mapping: 用于转换的字典或者函数mapping是函数时返回值必须为字典类型 :param mapping: 用于转换的字典或者函数mapping是函数时返回值必须为字典类型
:param data: 需要被转换的对象 :param data: 需要被转换的对象
@ -440,12 +441,16 @@ def _is_iterable(value):
def pretty_table_printer(dataset_or_ins) -> PrettyTable: def pretty_table_printer(dataset_or_ins) -> PrettyTable:
r""" r"""
:param dataset_or_ins: 传入一个dataSet或者instance :param dataset_or_ins: 传入一个dataSet或者instance
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
+-----------+-----------+-----------------+ .. code-block::
| field_1 | field_2 | field_3 |
+-----------+-----------+-----------------+ ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] | +-----------+-----------+-----------------+
+-----------+-----------+-----------------+ | field_1 | field_2 | field_3 |
+-----------+-----------+-----------------+
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
+-----------+-----------+-----------------+
:return: pretty table的形式返回根据terminal大小进行自动截断 :return: pretty table的形式返回根据terminal大小进行自动截断
""" """
x = PrettyTable() x = PrettyTable()

View File

@ -47,7 +47,7 @@ def rank_zero_call(fn: Callable):
rank_zero_call(add)(1, 2) rank_zero_call(add)(1, 2)
同时该函数还会设置 FASTNLP_NO_SYNC 2在这个环境下所有的 fastNLP 内置的 barrier 接口gather/broadcast 操作都没有任何 同时该函数还会设置 FASTNLP_NO_SYNC 2在这个环境下所有的 fastNLP 内置的 barrier 接口gather/broadcast 操作都没有任何
意义 意义
:param fn: 需要包裹的可执行的函数 :param fn: 需要包裹的可执行的函数
:return: :return:
@ -65,7 +65,7 @@ def rank_zero_call(fn: Callable):
def fastnlp_no_sync_context(level=2): def fastnlp_no_sync_context(level=2):
""" """
用于让 fastNLP barrier 以及 gather/broadcast等操作等同于只有1卡的多卡程序如果为 1 表示 fastNLP 里的barrier 操作失效 用于让 fastNLP barrier 以及 gather/broadcast等操作等同于只有1卡的多卡程序如果为 1 表示 fastNLP 里的barrier 操作失效
如果为 2 表示 barrier gather/broadcast 都失效 如果为 2 表示 barrier gather/broadcast 都失效
:param int level: 可选 [0, 1, 2] :param int level: 可选 [0, 1, 2]
:return: :return:
@ -84,9 +84,10 @@ def all_rank_call_context():
""" """
在多卡模式下该环境内会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0"使得 rank_zero_call 函数失效使得每个进程都会运行该函数 在多卡模式下该环境内会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0"使得 rank_zero_call 函数失效使得每个进程都会运行该函数
# 使用方式 使用方式::
with all_rank_call_context():
do_something # all rank will do with all_rank_call_context():
do_something # all rank will do
:param fn: :param fn:
:return: :return:

View File

@ -233,8 +233,8 @@ class DataBundle:
如果为False则报错 如果为False则报错
:param num_proc: 进程的数量请注意由于python语言的特性多少进程就会导致多少倍内存的增长 :param num_proc: 进程的数量请注意由于python语言的特性多少进程就会导致多少倍内存的增长
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} 就直接跳过这个 dataset :param ignore_miss_dataset: 如果 dataset 没有 {field_name} 就直接跳过这个 dataset
:param progress_desc 当show_progress_barm为True时可以显示当前tqdm正在处理的名称 :param progress_desc: 当show_progress_barm为True时可以显示当前tqdm正在处理的名称
:param show_progress_bar 是否显示tqdm进度条 :param show_progress_bar: 是否显示tqdm进度条
""" """
_progress_desc = progress_desc _progress_desc = progress_desc

View File

@ -0,0 +1,133 @@
import pytest
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.controllers.trainer import Evaluator
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.core.dataloaders.jittor_dataloader.fdl import JittorDataLoader
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
if _NEED_IMPORT_JITTOR:
import jittor as jt
from jittor import nn, Module
from jittor.dataset import Dataset
class JittorNormalModel_Classification(Module):
"""
基础的 Jittor 分类模型
"""
def __init__(self, num_labels, feature_dimension):
super(JittorNormalModel_Classification, self).__init__()
self.num_labels = num_labels
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
self.ac1 = nn.ReLU()
self.linear2 = nn.Linear(in_features=64, out_features=32)
self.ac2 = nn.ReLU()
self.output = nn.Linear(in_features=32, out_features=num_labels)
self.loss_fn = nn.CrossEntropyLoss()
def execute(self, x):
# It's similar to forward function in Pytorch
x = self.ac1(self.linear1(x))
x = self.ac2(self.linear2(x))
x = self.output(x)
return x
def train_step(self, x, y):
x = self(x)
return {"loss": self.loss_fn(x, y)}
def evaluate_step(self, x, y):
x = self(x)
return {"pred": x, "target": y.reshape((-1,))}
class JittorRandomMaxDataset(Dataset):
def __init__(self, num_samples, num_features):
super(JittorRandomMaxDataset, self).__init__()
self.x = jt.randn((num_samples, num_features))
self.y = self.x.argmax(dim=1)[0]
def __len__(self):
return len(self.y)
def __getitem__(self, item):
return {"x": self.x[item], "y": self.y[item]}
class TrainJittorConfig:
num_labels: int = 5
feature_dimension: int = 5
lr = 1e-1
batch_size: int = 4
shuffle: bool = True
@pytest.mark.parametrize("driver,device", [("jittor", None)])
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
def test_trainer_jittor(
driver,
device,
callbacks,
n_epochs=3,
):
model = JittorNormalModel_Classification(
num_labels=TrainJittorConfig.num_labels,
feature_dimension=TrainJittorConfig.feature_dimension
)
optimizer = nn.SGD(model.parameters(), lr=TrainJittorConfig.lr)
train_dataloader = JittorDataLoader(
dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
batch_size=TrainJittorConfig.batch_size,
shuffle=True,
# num_workers=4,
)
val_dataloader = JittorDataLoader(
dataset=JittorRandomMaxDataset(500, TrainJittorConfig.feature_dimension),
batch_size=TrainJittorConfig.batch_size,
shuffle=True,
# num_workers=4,
)
test_dataloader = JittorDataLoader(
dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
batch_size=TrainJittorConfig.batch_size,
shuffle=True,
# num_workers=4,
)
metrics = {"acc": Accuracy()}
trainer = Trainer(
model=model,
driver=driver,
device=device,
optimizers=optimizer,
train_dataloader=train_dataloader,
evaluate_dataloaders=val_dataloader,
validate_every=-1,
evaluate_fn="evaluate_step",
input_mapping=None,
output_mapping=None,
metrics=metrics,
n_epochs=n_epochs,
callbacks=callbacks,
# progress_bar="rich"
)
trainer.run()
evaluator = Evaluator(
model=model,
driver=driver,
dataloaders=test_dataloader,
evaluate_fn="evaluate_step",
metrics=metrics,
)
metric_results = evaluator.run()
assert metric_results["acc#acc"] > 0.80
if __name__ == "__main__":
# test_trainer_jittor("jittor", None, [RichCallback(100)])
pytest.main(['test_trainer_jittor.py']) # 只运行此模块

View File

@ -1,7 +1,6 @@
import pytest import pytest
import numpy as np import numpy as np
from datasets import Dataset as HfDataset from datasets import Dataset as HfDataset
from datasets import load_dataset
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
from fastNLP.core.dataset import DataSet as Fdataset from fastNLP.core.dataset import DataSet as Fdataset
@ -23,16 +22,12 @@ class MyDataset(Dataset):
def __getitem__(self, item): def __getitem__(self, item):
return self.data[item] return self.data[item]
# return {'x': [[1, 0], [2, 0, 1]]}
# return np.random.randn(3, 10)
# def __len__(self):
# return self.dataset_len
@pytest.mark.jittor @pytest.mark.jittor
class TestJittor: class TestJittor:
def test_v1(self): def test_jittor_dataset(self):
""" """
测试jittor类型的dataset使用fdl 测试jittor类型的dataset使用fdl
@ -40,13 +35,13 @@ class TestJittor:
""" """
dataset = MyDataset() dataset = MyDataset()
jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4)
# jtl.set_pad_val('x', 'y')
# jtl.set_input('x')
for batch in jtl: for batch in jtl:
print(batch) assert batch.size() == [4, 3, 4]
print(jtl.get_batch_indices()) jtl1 = JittorDataLoader(dataset, keep_numpy_array=False, batch_size=4, num_workers=2)
for batch in jtl1:
assert batch.size() == [4, 3, 4]
def test_v2(self): def test_fastnlp_Dataset(self):
""" """
测试fastnlp的dataset 测试fastnlp的dataset
@ -56,26 +51,27 @@ class TestJittor:
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True) jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True)
jtl.set_pad("x", -1) jtl.set_pad("x", -1)
jtl.set_ignore("y") jtl.set_ignore("y")
# jtl.set_pad_val('x', val=-1)
# jtl.set_input('x', 'y')
for batch in jtl: for batch in jtl:
assert batch['x'].size() == (16, 4) assert batch['x'].size() == (16, 4)
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True, num_workers=2)
def test_v3(self):
def test_huggingface_datasets(self):
dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100})
jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True) jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True)
# jtl.set_input('x', 'y')
for batch in jtl: for batch in jtl:
print(batch) assert batch['x'].size() == [4, 4]
assert len(batch['y']) == 4
def test_v4(self): def test_num_workers(self):
dataset = MyDataset() dataset = MyDataset()
dl = JittorDataLoader(dataset, batch_size=4, num_workers=2) dl = JittorDataLoader(dataset, batch_size=4, num_workers=2)
print(len(dl))
for idx, batch in enumerate(dl): for idx, batch in enumerate(dl):
print(batch.shape, idx) assert batch.shape == [4, 3, 4]
for idx, batch in enumerate(dl): for idx, batch in enumerate(dl):
print(batch.shape, idx) assert batch.shape == [4, 3, 4]
def test_v5(self): def test_v5(self):
dataset = MyDataset() dataset = MyDataset()

View File

@ -6,19 +6,19 @@ from fastNLP.core.dataset import DataSet
from fastNLP.core.log import logger from fastNLP.core.log import logger
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE: if _NEED_IMPORT_PADDLE:
from paddle.io import Dataset, DataLoader from paddle.io import Dataset
import paddle import paddle
else: else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset from fastNLP.core.utils.dummy_class import DummyClass as Dataset
class RandomDataset(Dataset): class RandomDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
image = np.random.random((10, 5)).astype('float32') image = np.random.random((10, 5)).astype('float32')
return {'image': image, 'label': [[0, 1], [1, 2, 3, 4]]} return {'image': paddle.to_tensor(image), 'label': [[0, 1], [1, 2, 3, 4]]}
def __len__(self): def __len__(self):
return 10 return 10
@ -33,16 +33,22 @@ class TestPaddle:
fdl = PaddleDataLoader(ds, batch_size=2) fdl = PaddleDataLoader(ds, batch_size=2)
# fdl = DataLoader(ds, batch_size=2, shuffle=True) # fdl = DataLoader(ds, batch_size=2, shuffle=True)
for batch in fdl: for batch in fdl:
print(batch) assert batch['image'].shape == [2, 10, 5]
assert batch['label'].shape == [2, 2, 4]
# print(fdl.get_batch_indices()) # print(fdl.get_batch_indices())
def test_fdl_batch_indices(self): def test_fdl_fastnlp_dataset(self):
ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})
fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True) fdl = PaddleDataLoader(ds, batch_size=3, shuffle=False, drop_last=True)
fdl.set_ignore('y')
fdl.set_pad('x', -1)
for batch in fdl: for batch in fdl:
assert len(fdl.get_batch_indices()) == 4 assert len(fdl.get_batch_indices()) == 3
print(batch) assert 'y' not in batch
print(fdl.get_batch_indices()) assert batch['x'].shape == [3, 3]
with pytest.raises(ValueError):
PaddleDataLoader(ds, batch_size=3, collate_fn=None)
def test_set_inputs_and_set_pad_val(self): def test_set_inputs_and_set_pad_val(self):
logger.setLevel("DEBUG") logger.setLevel("DEBUG")
@ -50,11 +56,8 @@ class TestPaddle:
fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True) fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True)
fdl.set_pad('label', -1) fdl.set_pad('label', -1)
for batch in fdl: for batch in fdl:
print(batch['image'])
assert batch['image'].shape == [2, 10, 5] assert batch['image'].shape == [2, 10, 5]
print(batch)
fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True) fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True)
fdl1.set_ignore('label') fdl1.set_ignore('label')
for batch in fdl1: for batch in fdl1:
assert batch['image'].shape == [4, 10, 5] assert batch['image'].shape == [4, 10, 5]
print(batch)

View File

@ -4,6 +4,7 @@ from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_t
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.io.data_bundle import DataBundle from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core import Trainer
if _NEED_IMPORT_TORCH: if _NEED_IMPORT_TORCH:
import torch import torch