mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 11:17:50 +08:00
merge
This commit is contained in:
commit
f1fa665e7c
@ -146,11 +146,13 @@ class CallbackManager:
|
|||||||
r"""
|
r"""
|
||||||
用于断点重训的 callback 的保存函数;
|
用于断点重训的 callback 的保存函数;
|
||||||
该函数主要涉及两个方面:
|
该函数主要涉及两个方面:
|
||||||
|
|
||||||
1. callback 的状态的保存;我们会调用每一个 callback 的 `on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着
|
1. callback 的状态的保存;我们会调用每一个 callback 的 `on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着
|
||||||
断点重训应当保存的状态;
|
断点重训应当保存的状态;
|
||||||
2. 每一个具体的 callback 函数的 filter 的状态;
|
2. 每一个具体的 callback 函数的 filter 的状态;
|
||||||
|
|
||||||
:return: 一个包含上述内容的字典::
|
:return: 一个包含上述内容的字典:
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
{
|
{
|
||||||
"callback_name_1": {
|
"callback_name_1": {
|
||||||
@ -158,6 +160,7 @@ class CallbackManager:
|
|||||||
"filter_states": {"on_train_begin": filter1.state_dict(), ...}
|
"filter_states": {"on_train_begin": filter1.state_dict(), ...}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
states = {}
|
states = {}
|
||||||
|
@ -10,13 +10,13 @@ class TorchGradClipCallback(Callback):
|
|||||||
在每次 optimizer update 之前将 parameter 进行 clip
|
在每次 optimizer update 之前将 parameter 进行 clip
|
||||||
|
|
||||||
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数
|
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数
|
||||||
:param str clip_type: 支持'norm', 'value'两种::
|
:param str clip_type: 支持'norm', 'value'两种:
|
||||||
|
|
||||||
1 'norm', 将gradient的norm rescale到[-clip_value, clip_value]
|
1. 'norm', 将gradient的norm rescale到[-clip_value, clip_value]
|
||||||
|
2. 'value', 将gradient限制在[-clip_value, clip_value],
|
||||||
2 'value', 将gradient限制在[-clip_value, clip_value],
|
|
||||||
小于-clip_value的gradient被赋值为-clip_value;
|
小于-clip_value的gradient被赋值为-clip_value;
|
||||||
大于clip_value的gradient被赋值为clip_value.
|
大于clip_value的gradient被赋值为clip_value.
|
||||||
|
|
||||||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。
|
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。
|
||||||
如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。
|
如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。
|
||||||
"""
|
"""
|
||||||
|
@ -9,6 +9,7 @@ from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPad
|
|||||||
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
|
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
|
||||||
from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder
|
from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder
|
||||||
from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
|
from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
|
||||||
|
from .jittor_padder import JittorTensorPadder, JittorSequencePadder, JittorNumberPadder
|
||||||
from .exceptions import *
|
from .exceptions import *
|
||||||
|
|
||||||
|
|
||||||
@ -91,6 +92,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
|
|||||||
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
|
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
|
||||||
elif backend == 'paddle':
|
elif backend == 'paddle':
|
||||||
return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
|
return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
|
||||||
|
elif backend == 'jittor':
|
||||||
|
return JittorNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).")
|
raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).")
|
||||||
|
|
||||||
@ -103,6 +106,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
|
|||||||
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
|
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
|
||||||
elif backend == 'paddle':
|
elif backend == 'paddle':
|
||||||
return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
|
return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
|
||||||
|
elif backend == 'jittor':
|
||||||
|
return JittorSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).")
|
raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).")
|
||||||
|
|
||||||
@ -116,6 +121,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
|
|||||||
return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
|
return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
|
||||||
elif backend == 'paddle':
|
elif backend == 'paddle':
|
||||||
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
|
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
|
||||||
|
elif backend == 'jittor':
|
||||||
|
return JittorTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).")
|
raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).")
|
||||||
|
|
||||||
|
195
fastNLP/core/collators/padders/jittor_padder.py
Normal file
195
fastNLP/core/collators/padders/jittor_padder.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
__all__ = [
|
||||||
|
'JittorNumberPadder',
|
||||||
|
'JittorSequencePadder',
|
||||||
|
'JittorTensorPadder'
|
||||||
|
]
|
||||||
|
|
||||||
|
from inspect import isclass
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||||
|
|
||||||
|
if _NEED_IMPORT_JITTOR:
|
||||||
|
import jittor
|
||||||
|
|
||||||
|
numpy_to_jittor_dtype_dict = {
|
||||||
|
np.bool_: 'bool',
|
||||||
|
np.uint8: 'uint8',
|
||||||
|
np.int8: "int8",
|
||||||
|
np.int16: "int16",
|
||||||
|
np.int32: "int32",
|
||||||
|
np.int64: "int64",
|
||||||
|
np.float16: "float16",
|
||||||
|
np.float32: 'float32',
|
||||||
|
np.float64: 'float32', # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了
|
||||||
|
}
|
||||||
|
# number_to_jittor_dtype_dict = {
|
||||||
|
# float: 'float32', # 因为 paddle.tensor([1], dtype=float)是paddle.float64
|
||||||
|
# int: 'int64',
|
||||||
|
# bool: 'bool'
|
||||||
|
# }
|
||||||
|
|
||||||
|
from .padder import Padder
|
||||||
|
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
|
||||||
|
from .exceptions import *
|
||||||
|
|
||||||
|
|
||||||
|
def is_jittor_tensor(dtype):
|
||||||
|
if not isclass(dtype) and isinstance(dtype, jittor.jittor_core.Var):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_jittor_dtype_str(dtype):
|
||||||
|
try:
|
||||||
|
if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
|
||||||
|
'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
|
||||||
|
u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8',
|
||||||
|
u'int16', u'int32', u'int64', u'uint8'}:
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dtype(ele_dtype, dtype, class_name):
|
||||||
|
if not (ele_dtype is None or (
|
||||||
|
is_number_or_numpy_number(ele_dtype) or is_jittor_tensor(ele_dtype) or is_jittor_dtype_str(dtype))):
|
||||||
|
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
|
||||||
|
f"or numpy numbers or jittor.Var but get `{ele_dtype}`.")
|
||||||
|
|
||||||
|
if dtype is not None:
|
||||||
|
if not (is_jittor_tensor(dtype) or is_number(dtype) or is_jittor_dtype_str(dtype)):
|
||||||
|
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
|
||||||
|
f"or jittor.dtype but get `{dtype}`.")
|
||||||
|
# dtype = number_to_jittor_dtype_dict.get(dtype, dtype)
|
||||||
|
else:
|
||||||
|
# if (is_number(ele_dtype) or is_jittor_tensor(ele_dtype)):
|
||||||
|
# # ele_dtype = number_to_jittor_dtype_dict.get(ele_dtype, ele_dtype)
|
||||||
|
# dtype = ele_dtype
|
||||||
|
# elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了
|
||||||
|
# dtype = numpy_to_jittor_dtype_dict.get(ele_dtype.type)
|
||||||
|
if is_numpy_generic_class(ele_dtype):
|
||||||
|
dtype = numpy_to_jittor_dtype_dict.get(ele_dtype)
|
||||||
|
else:
|
||||||
|
dtype = ele_dtype
|
||||||
|
|
||||||
|
return dtype
|
||||||
|
|
||||||
|
|
||||||
|
class JittorNumberPadder(Padder):
|
||||||
|
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
|
||||||
|
"""
|
||||||
|
可以将形如 [1, 2, 3] 这类的数据转为 jittor.Var([1, 2, 3])
|
||||||
|
|
||||||
|
:param pad_val: 该值无意义
|
||||||
|
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。
|
||||||
|
:param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等
|
||||||
|
"""
|
||||||
|
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
|
||||||
|
super().__init__(pad_val=pad_val, dtype=dtype)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pad(batch_field, pad_val, dtype):
|
||||||
|
return jittor.Var(np.array(batch_field, dtype=dtype))
|
||||||
|
|
||||||
|
|
||||||
|
class JittorSequencePadder(Padder):
|
||||||
|
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
|
||||||
|
"""
|
||||||
|
将类似于 [[1], [1, 2]] 的内容 pad 为 jittor.Var([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。
|
||||||
|
|
||||||
|
:param pad_val: 需要 pad 的值。
|
||||||
|
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。
|
||||||
|
:param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等
|
||||||
|
"""
|
||||||
|
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
|
||||||
|
super().__init__(pad_val=pad_val, dtype=dtype)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pad(batch_field, pad_val, dtype):
|
||||||
|
tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
class JittorTensorPadder(Padder):
|
||||||
|
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
|
||||||
|
"""
|
||||||
|
目前支持 [jittor.Var([3, 2], jittor.Var([1])] 类似的。若内部元素不为 jittor.Var ,则必须含有 tolist() 方法。
|
||||||
|
|
||||||
|
:param pad_val: 需要 pad 的值。
|
||||||
|
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。
|
||||||
|
:param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等
|
||||||
|
"""
|
||||||
|
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
|
||||||
|
super().__init__(pad_val=pad_val, dtype=dtype)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pad(batch_field, pad_val, dtype):
|
||||||
|
try:
|
||||||
|
if not isinstance(batch_field[0], jittor.Var):
|
||||||
|
batch_field = [jittor.Var(np.array(field.tolist(), dtype=dtype)) for field in batch_field]
|
||||||
|
except AttributeError:
|
||||||
|
raise RuntimeError(f"If the field is not a jittor.Var (it is {type(batch_field[0])}), "
|
||||||
|
f"it must have tolist() method.")
|
||||||
|
|
||||||
|
shapes = [field.shape for field in batch_field]
|
||||||
|
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
|
||||||
|
# if dtype is not None:
|
||||||
|
# tensor = jittor.full(max_shape, pad_val, dtype=dtype)
|
||||||
|
# else:
|
||||||
|
tensor = jittor.full(max_shape, pad_val, dtype=dtype)
|
||||||
|
for i, field in enumerate(batch_field):
|
||||||
|
slices = (i,) + tuple(slice(0, s) for s in shapes[i])
|
||||||
|
tensor[slices] = field
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def fill_tensor(batch_field, padded_batch, dtype):
|
||||||
|
"""
|
||||||
|
将 batch_field 中的值填入到 tensor 中。
|
||||||
|
|
||||||
|
:param batch_field: 需要填充进入 array 中的内容
|
||||||
|
:param padded_batch: 待填充的 tensor
|
||||||
|
:param dtype: 数据的类别
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if padded_batch.ndim == 2:
|
||||||
|
for i, content_i in enumerate(batch_field):
|
||||||
|
padded_batch[i, :len(content_i)] = jittor.Var(np.array(content_i, dtype=dtype))
|
||||||
|
elif padded_batch.ndim == 3:
|
||||||
|
for i, content_i in enumerate(batch_field):
|
||||||
|
for j, content_ii in enumerate(content_i):
|
||||||
|
padded_batch[i, j, :len(content_ii)] = jittor.Var(np.array(content_ii, dtype=dtype))
|
||||||
|
elif padded_batch.ndim == 4:
|
||||||
|
try: # 应该是图像,所以直接应该就 ok 了。
|
||||||
|
padded_batch = np.array(batch_field)
|
||||||
|
except:
|
||||||
|
for i, content_i in enumerate(batch_field):
|
||||||
|
for j, content_ii in enumerate(content_i):
|
||||||
|
for k, content_iii in enumerate(content_ii):
|
||||||
|
padded_batch[i, j, k, :len(content_iii)] = jittor.Var(np.array(content_iii, dtype=dtype))
|
||||||
|
elif padded_batch.ndim == 1:
|
||||||
|
padded_batch[:] = jittor.Var(np.array(batch_field, dtype=dtype))
|
||||||
|
else:
|
||||||
|
raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
|
||||||
|
"report.")
|
||||||
|
return padded_batch
|
||||||
|
|
||||||
|
|
||||||
|
def get_padded_jittor_tensor(batch_field, dtype=None, pad_val=0):
|
||||||
|
"""
|
||||||
|
例如:
|
||||||
|
[[1,2], [3]] -> jittor.LongTensor([[1, 2], [3, 0]])
|
||||||
|
|
||||||
|
:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列)
|
||||||
|
/4d(多为图片)。
|
||||||
|
:param dtype: 目标类别是什么
|
||||||
|
:param pad_val: pad 的 value
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
shapes = get_shape(batch_field)
|
||||||
|
tensor = jittor.full(shapes, pad_val, dtype=dtype)
|
||||||
|
tensor = fill_tensor(batch_field, tensor, dtype=dtype)
|
||||||
|
return tensor
|
@ -51,22 +51,29 @@ class Evaluator:
|
|||||||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`;
|
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`;
|
||||||
:param fp16: 是否使用 fp16 。
|
:param fp16: 是否使用 fp16 。
|
||||||
:param verbose: 是否打印 evaluate 的结果。
|
:param verbose: 是否打印 evaluate 的结果。
|
||||||
:param kwargs:
|
:param \**kwargs:
|
||||||
bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout
|
See below
|
||||||
与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论
|
:kwargs:
|
||||||
|
* *model_use_eval_mode* (``bool``) --
|
||||||
|
是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的
|
||||||
|
dropout 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论
|
||||||
该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。
|
该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。
|
||||||
TODO 还没完成。
|
TODO 还没完成。
|
||||||
Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的
|
* *auto_tensor_conversion_for_metric* (``Union[bool]``) --
|
||||||
tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象,
|
是否自动将输出中的 tensor 适配到 metrics 支持的。例如 model 输出是
|
||||||
当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数
|
paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象,当 auto_tensor_conversion_for_metric 为True时,fastNLP 将
|
||||||
不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的输出 tensor 类型通过 driver 来决定,
|
自动将输出中 paddle 的 tensor (其它非 tensor 的参数不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的
|
||||||
metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换,请使用 input_mapping、output_mapping 参数进行。
|
输出 tensor 类型通过 driver 来决定,metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换,
|
||||||
use_dist_sampler: 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持
|
请使用 input_mapping、output_mapping 参数进行。
|
||||||
|
* *use_dist_sampler* --
|
||||||
|
是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持
|
||||||
分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。
|
分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。
|
||||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
|
* *output_from_new_proc* --
|
||||||
|
应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
|
||||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
|
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
|
||||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
|
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
|
||||||
progress_bar: evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测
|
* *progress_bar* --
|
||||||
|
evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测
|
||||||
到当前terminal为交互型则使用 rich,否则使用 raw。
|
到当前terminal为交互型则使用 rich,否则使用 raw。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -67,8 +67,8 @@ class Trainer(TrainerEventTrigger):
|
|||||||
要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP;
|
要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP;
|
||||||
|
|
||||||
:param model: 训练所需要的模型,目前支持 pytorch;
|
:param model: 训练所需要的模型,目前支持 pytorch;
|
||||||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle
|
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle 等
|
||||||
等国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练
|
国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练
|
||||||
:param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict;
|
:param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict;
|
||||||
:param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List;
|
:param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List;
|
||||||
:param device: 该参数用来指定具体训练时使用的机器;注意当该参数为 None 时,fastNLP 不会将模型和数据进行设备之间的移动处理,但是你
|
:param device: 该参数用来指定具体训练时使用的机器;注意当该参数为 None 时,fastNLP 不会将模型和数据进行设备之间的移动处理,但是你
|
||||||
@ -81,6 +81,7 @@ class Trainer(TrainerEventTrigger):
|
|||||||
3. int: 将使用device_id为该值的gpu进行训练;如果值为 -1,那么默认使用全部的显卡,此时是 `TorchDDPDriver`;
|
3. int: 将使用device_id为该值的gpu进行训练;如果值为 -1,那么默认使用全部的显卡,此时是 `TorchDDPDriver`;
|
||||||
4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`;
|
4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`;
|
||||||
5. None: 为None则不对模型进行任何处理;
|
5. None: 为None则不对模型进行任何处理;
|
||||||
|
|
||||||
:param n_epochs: 训练总共的 epoch 的数量,默认为 20;
|
:param n_epochs: 训练总共的 epoch 的数量,默认为 20;
|
||||||
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认
|
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认
|
||||||
为 None;
|
为 None;
|
||||||
@ -121,26 +122,27 @@ class Trainer(TrainerEventTrigger):
|
|||||||
如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。
|
如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。
|
||||||
:param larger_better: monitor 的值是否是越大越好。
|
:param larger_better: monitor 的值是否是越大越好。
|
||||||
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None;
|
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None;
|
||||||
:param kwargs: 一些其它的可能需要的参数;
|
:param kwargs: 一些其它的可能需要的参数,见下方的说明
|
||||||
torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
|
:kwargs:
|
||||||
data_device: 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上;
|
* *torch_non_blocking* -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
|
||||||
|
* *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上;
|
||||||
注意如果 model_device 为 None,那么 data_device 不会起作用;
|
注意如果 model_device 为 None,那么 data_device 不会起作用;
|
||||||
torch_ddp_kwargs: 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数;仅用于 pytorch ddp 训练。例如传入
|
* *torch_ddp_kwargs* -- 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数;仅用于 pytorch ddp 训练。例如传入
|
||||||
{'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等。
|
{'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等。
|
||||||
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
|
* *set_grad_to_none* -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
|
||||||
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch
|
* *use_dist_sampler* -- 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch
|
||||||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
|
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
|
||||||
evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
|
* *evaluate_use_dist_sampler* -- 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
|
||||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
|
* *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
|
||||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
|
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
|
||||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
|
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
|
||||||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
|
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
|
||||||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果
|
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果
|
||||||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
|
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
|
||||||
train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。
|
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。
|
||||||
train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。
|
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。
|
||||||
evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。
|
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。
|
||||||
evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。
|
* *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
self.marker = marker
|
self.marker = marker
|
||||||
@ -417,7 +419,8 @@ class Trainer(TrainerEventTrigger):
|
|||||||
def on(cls, event: Event, marker: Optional[str] = None):
|
def on(cls, event: Event, marker: Optional[str] = None):
|
||||||
r"""
|
r"""
|
||||||
函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制;
|
函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制;
|
||||||
支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如
|
支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如::
|
||||||
|
|
||||||
Trainer.__init__():
|
Trainer.__init__():
|
||||||
on_after_trainer_initialized(trainer, driver)
|
on_after_trainer_initialized(trainer, driver)
|
||||||
Trainer.run():
|
Trainer.run():
|
||||||
@ -445,11 +448,13 @@ class Trainer(TrainerEventTrigger):
|
|||||||
self.on_exception(trainer, exception)
|
self.on_exception(trainer, exception)
|
||||||
finally:
|
finally:
|
||||||
on_train_end(trainer)
|
on_train_end(trainer)
|
||||||
|
|
||||||
其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
|
其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
|
||||||
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中
|
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中
|
||||||
特定的时间调用。
|
特定的时间调用。
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
from fastNLP import Event
|
from fastNLP import Event
|
||||||
@Trainer.on(Event.on_save_model())
|
@Trainer.on(Event.on_save_model())
|
||||||
def do_something_1(trainer):
|
def do_something_1(trainer):
|
||||||
|
@ -26,7 +26,8 @@ class State(dict):
|
|||||||
|
|
||||||
为了实现断点重训,用户应当保证其保存的信息都是可序列化的;
|
为了实现断点重训,用户应当保证其保存的信息都是可序列化的;
|
||||||
|
|
||||||
推荐的使用方式:
|
推荐的使用方式::
|
||||||
|
|
||||||
>>> state = State()
|
>>> state = State()
|
||||||
>>> state["best_accuracy"] = 0.9
|
>>> state["best_accuracy"] = 0.9
|
||||||
>>> print(state["best_accuracy"])
|
>>> print(state["best_accuracy"])
|
||||||
|
@ -64,38 +64,40 @@ class JittorDataLoader:
|
|||||||
:param collate_fn: 对取得到的数据进行打包的callable函数
|
:param collate_fn: 对取得到的数据进行打包的callable函数
|
||||||
:param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型
|
:param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型
|
||||||
"""
|
"""
|
||||||
# TODO 支持fastnlp dataset
|
|
||||||
# TODO 验证支持replacesampler (以后完成)
|
# TODO 验证支持replacesampler (以后完成)
|
||||||
# 是否为 jittor 类型的 dataset
|
# FastNLP Datset, collate_fn not None
|
||||||
|
if isinstance(dataset, FDataSet) and collate_fn is None:
|
||||||
|
raise ValueError("When use FastNLP DataSet, collate_fn must be not None")
|
||||||
|
|
||||||
|
if not isinstance(dataset, _JittorDataset):
|
||||||
|
self.dataset = _JittorDataset(dataset)
|
||||||
|
|
||||||
if isinstance(collate_fn, str):
|
if isinstance(collate_fn, str):
|
||||||
if collate_fn == "auto":
|
if collate_fn == "auto":
|
||||||
if isinstance(dataset, FDataSet):
|
if isinstance(self.dataset.dataset, FDataSet):
|
||||||
self._collate_fn = dataset.collator
|
self.collate_fn = self.dataset.dataset.collator
|
||||||
self._collate_fn.set_backend(backend="jittor")
|
self.collate_fn.set_backend(backend="jittor")
|
||||||
else:
|
else:
|
||||||
self._collate_fn = Collator(backend="jittor")
|
self.collate_fn = Collator(backend="jittor")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
|
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
|
||||||
elif isinstance(collate_fn, Callable):
|
elif isinstance(collate_fn, Callable):
|
||||||
if collate_fn is not collate_batch:
|
if collate_fn is not collate_batch:
|
||||||
self._collate_fn = collate_fn
|
self.collate_fn = collate_fn
|
||||||
else:
|
else:
|
||||||
self._collate_fn = collate_batch
|
self.collate_fn = collate_batch
|
||||||
|
|
||||||
self.dataset = _JittorDataset(dataset)
|
|
||||||
|
|
||||||
self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
|
self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
|
||||||
num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad,
|
num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad,
|
||||||
keep_numpy_array=keep_numpy_array, endless=endless)
|
keep_numpy_array=keep_numpy_array, endless=endless)
|
||||||
|
# 将内部dataset批次设置为1
|
||||||
if isinstance(self.dataset.dataset, Dataset):
|
if isinstance(self.dataset.dataset, Dataset):
|
||||||
self.dataset.dataset.set_attrs(batch_size=1)
|
self.dataset.dataset.set_attrs(batch_size=1)
|
||||||
# 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数
|
|
||||||
# self._collate_fn = _collate_fn
|
|
||||||
self.cur_batch_indices = None
|
self.cur_batch_indices = None
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
# TODO 第一次迭代后不能设置collate_fn,设置是无效的
|
# TODO 第一次迭代后不能设置collate_fn,设置是无效的
|
||||||
self.collate_fn = self._collate_fn
|
|
||||||
if self.cur_batch_indices is None:
|
if self.cur_batch_indices is None:
|
||||||
self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn))
|
self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn))
|
||||||
for indices, data in self.dataset.__iter__():
|
for indices, data in self.dataset.__iter__():
|
||||||
@ -107,8 +109,8 @@ class JittorDataLoader:
|
|||||||
return len(self.dataset) // self.dataset.batch_size
|
return len(self.dataset) // self.dataset.batch_size
|
||||||
return (len(self.dataset) - 1) // self.dataset.batch_size + 1
|
return (len(self.dataset) - 1) // self.dataset.batch_size + 1
|
||||||
|
|
||||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
|
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
|
||||||
pad_fn:Callable=None) -> Collator:
|
pad_fn: Callable = None) -> "JittorDataLoader":
|
||||||
"""
|
"""
|
||||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
|
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
|
||||||
|
|
||||||
@ -127,16 +129,18 @@ class JittorDataLoader:
|
|||||||
形式,输出将被直接作为结果输出。
|
形式,输出将被直接作为结果输出。
|
||||||
:return: 返回 Collator 自身
|
:return: 返回 Collator 自身
|
||||||
"""
|
"""
|
||||||
if isinstance(self._collate_fn, Collator):
|
if isinstance(self.collate_fn, Collator):
|
||||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
|
self.collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
|
||||||
return self._collate_fn
|
backend=backend)
|
||||||
|
return self
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
|
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
|
||||||
|
|
||||||
def set_ignore(self, *field_names) -> Collator:
|
def set_ignore(self, *field_names) -> "JittorDataLoader":
|
||||||
"""
|
"""
|
||||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
collator.set_ignore('field1', 'field2')
|
collator.set_ignore('field1', 'field2')
|
||||||
|
|
||||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||||
@ -144,9 +148,9 @@ class JittorDataLoader:
|
|||||||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
|
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
|
||||||
:return: 返回 Collator 自身
|
:return: 返回 Collator 自身
|
||||||
"""
|
"""
|
||||||
if isinstance(self._collate_fn, Collator):
|
if isinstance(self.collate_fn, Collator):
|
||||||
self._collate_fn.set_ignore(*field_names)
|
self.collate_fn.set_ignore(*field_names)
|
||||||
return self._collate_fn
|
return self
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
||||||
|
|
||||||
@ -158,5 +162,6 @@ class JittorDataLoader:
|
|||||||
"""
|
"""
|
||||||
return self.cur_batch_indices
|
return self.cur_batch_indices
|
||||||
|
|
||||||
|
|
||||||
def prepare_jittor_dataloader():
|
def prepare_jittor_dataloader():
|
||||||
...
|
...
|
||||||
|
@ -9,7 +9,6 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
|
|||||||
|
|
||||||
if _NEED_IMPORT_PADDLE:
|
if _NEED_IMPORT_PADDLE:
|
||||||
from paddle.io import DataLoader, Dataset, Sampler
|
from paddle.io import DataLoader, Dataset, Sampler
|
||||||
from paddle.fluid.dataloader.collate import default_collate_fn
|
|
||||||
else:
|
else:
|
||||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
|
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
|
||||||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
|
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
|
||||||
@ -52,6 +51,9 @@ class PaddleDataLoader(DataLoader):
|
|||||||
num_workers: int = 0, use_buffer_reader: bool = True,
|
num_workers: int = 0, use_buffer_reader: bool = True,
|
||||||
use_shared_memory: bool = True, timeout: int = 0,
|
use_shared_memory: bool = True, timeout: int = 0,
|
||||||
worker_init_fn: Callable = None, persistent_workers=False) -> None:
|
worker_init_fn: Callable = None, persistent_workers=False) -> None:
|
||||||
|
# FastNLP Datset, collate_fn not None
|
||||||
|
if isinstance(dataset, FDataSet) and collate_fn is None:
|
||||||
|
raise ValueError("When use FastNLP DataSet, collate_fn must be not None")
|
||||||
|
|
||||||
if not isinstance(dataset, _PaddleDataset):
|
if not isinstance(dataset, _PaddleDataset):
|
||||||
dataset = _PaddleDataset(dataset)
|
dataset = _PaddleDataset(dataset)
|
||||||
@ -66,10 +68,10 @@ class PaddleDataLoader(DataLoader):
|
|||||||
if isinstance(collate_fn, str):
|
if isinstance(collate_fn, str):
|
||||||
if collate_fn == 'auto':
|
if collate_fn == 'auto':
|
||||||
if isinstance(dataset.dataset, FDataSet):
|
if isinstance(dataset.dataset, FDataSet):
|
||||||
self._collate_fn = dataset.dataset.collator
|
collate_fn = dataset.dataset.collator
|
||||||
self._collate_fn.set_backend(backend="paddle")
|
collate_fn.set_backend(backend="paddle")
|
||||||
else:
|
else:
|
||||||
self._collate_fn = Collator(backend="paddle")
|
collate_fn = Collator(backend="paddle")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
|
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
|
||||||
@ -142,6 +144,7 @@ class PaddleDataLoader(DataLoader):
|
|||||||
"""
|
"""
|
||||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
collator.set_ignore('field1', 'field2')
|
collator.set_ignore('field1', 'field2')
|
||||||
|
|
||||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||||
@ -187,7 +190,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
|
|||||||
dl_bundle = {}
|
dl_bundle = {}
|
||||||
for name, ds in ds_or_db.iter_datasets():
|
for name, ds in ds_or_db.iter_datasets():
|
||||||
if 'train' in name:
|
if 'train' in name:
|
||||||
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places,
|
dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
|
||||||
return_list=return_list,
|
return_list=return_list,
|
||||||
batch_sampler=batch_sampler, batch_size=train_batch_size,
|
batch_sampler=batch_sampler, batch_size=train_batch_size,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
@ -197,7 +200,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
|
|||||||
timeout=timeout, worker_init_fn=worker_init_fn,
|
timeout=timeout, worker_init_fn=worker_init_fn,
|
||||||
persistent_workers=persistent_workers)
|
persistent_workers=persistent_workers)
|
||||||
else:
|
else:
|
||||||
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places,
|
dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
|
||||||
return_list=return_list,
|
return_list=return_list,
|
||||||
batch_sampler=batch_sampler, batch_size=non_train_batch_size,
|
batch_sampler=batch_sampler, batch_size=non_train_batch_size,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
|
@ -153,6 +153,7 @@ class TorchDataLoader(DataLoader):
|
|||||||
"""
|
"""
|
||||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
collator.set_ignore('field1', 'field2')
|
collator.set_ignore('field1', 'field2')
|
||||||
|
|
||||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||||
|
@ -87,8 +87,8 @@ class Driver(ABC):
|
|||||||
|
|
||||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型;
|
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型;
|
||||||
:param fn: 调用该函数进行一次计算。
|
:param fn: 调用该函数进行一次计算。
|
||||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call
|
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call 函
|
||||||
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward;
|
数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward;
|
||||||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查);
|
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查);
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Each specific driver should implemented its own `model_call` function.")
|
raise NotImplementedError("Each specific driver should implemented its own `model_call` function.")
|
||||||
@ -109,6 +109,7 @@ class Driver(ABC):
|
|||||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward`
|
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward`
|
||||||
函数,然后给出 warning;
|
函数,然后给出 warning;
|
||||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错;
|
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错;
|
||||||
|
|
||||||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的
|
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的
|
||||||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此
|
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此
|
||||||
可能需要额外标记最初传入 driver 的模型是哪种形式的;
|
可能需要额外标记最初传入 driver 的模型是哪种形式的;
|
||||||
|
@ -33,11 +33,12 @@ class JittorDriver(Driver):
|
|||||||
f"`jittor.Module` type.")
|
f"`jittor.Module` type.")
|
||||||
super(JittorDriver, self).__init__(model)
|
super(JittorDriver, self).__init__(model)
|
||||||
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
|
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
|
||||||
self.grad_scaler = _grad_scaler()
|
self.grad_scaler = _grad_scaler()
|
||||||
|
|
||||||
|
# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
|
||||||
|
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
|
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
|
||||||
# 在fastnlp中实现了JittorDataLoader
|
# 在fastnlp中实现了JittorDataLoader
|
||||||
|
@ -60,8 +60,8 @@ class JittorSingleDriver(JittorDriver):
|
|||||||
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
|
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
|
||||||
return fn, None
|
return fn, None
|
||||||
elif fn in {"train_step", "evaluate_step"}:
|
elif fn in {"train_step", "evaluate_step"}:
|
||||||
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
|
logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
|
||||||
return self.model, self.model.forward
|
return self.model, self.model.execute
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
|
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
|
||||||
|
|
||||||
@ -98,3 +98,9 @@ class JittorSingleDriver(JittorDriver):
|
|||||||
return dataloader
|
return dataloader
|
||||||
else:
|
else:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
"""
|
||||||
|
使用单个 GPU 时,jittor 底层自动实现调配,无需额外操作
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
@ -172,6 +172,7 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=None) ->List:
|
|||||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。
|
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。
|
||||||
|
|
||||||
example::
|
example::
|
||||||
|
|
||||||
obj = {
|
obj = {
|
||||||
'a': [1, 1],
|
'a': [1, 1],
|
||||||
'b': [[1, 2], [1, 2]],
|
'b': [[1, 2], [1, 2]],
|
||||||
|
@ -553,7 +553,8 @@ class TorchDDPDriver(TorchDriver):
|
|||||||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过
|
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过
|
||||||
pickle 进行序列化,接收到之后再反序列化。
|
pickle 进行序列化,接收到之后再反序列化。
|
||||||
|
|
||||||
example:
|
example::
|
||||||
|
|
||||||
obj = {
|
obj = {
|
||||||
'a': [1, 1],
|
'a': [1, 1],
|
||||||
'b': [[1, 2], [1, 2]],
|
'b': [[1, 2], [1, 2]],
|
||||||
|
@ -175,7 +175,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) -
|
|||||||
"""
|
"""
|
||||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。
|
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。
|
||||||
|
|
||||||
example:
|
example::
|
||||||
|
|
||||||
obj = {
|
obj = {
|
||||||
'a': [1, 1],
|
'a': [1, 1],
|
||||||
'b': [[1, 2], [1, 2]],
|
'b': [[1, 2], [1, 2]],
|
||||||
|
@ -182,9 +182,11 @@ def replace_sampler(dataloader: "DataLoader", sampler):
|
|||||||
的类,而不是直接的 DataLoader;
|
的类,而不是直接的 DataLoader;
|
||||||
|
|
||||||
如果需要定制自己的 dataloader,保证以下两点:
|
如果需要定制自己的 dataloader,保证以下两点:
|
||||||
|
|
||||||
1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中;
|
1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中;
|
||||||
2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性
|
2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性
|
||||||
来获取实际的参数的值;
|
来获取实际的参数的值;
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 拿到实例属性;
|
# 拿到实例属性;
|
||||||
|
@ -1,18 +1,20 @@
|
|||||||
r"""
|
r"""
|
||||||
Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger,
|
Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger,
|
||||||
具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API
|
具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API
|
||||||
使用方式:
|
|
||||||
from fastNLP import _logger
|
|
||||||
#
|
|
||||||
# _logger 可以和 logging.Logger 一样使用
|
|
||||||
_logger.info('your msg')
|
|
||||||
_logger.error('your msg')
|
|
||||||
|
|
||||||
# _logger 新增的API
|
使用方式::
|
||||||
# 将日志输出到文件,以及输出的日志等级
|
|
||||||
_logger.add_file('/path/to/log', level='INFO')
|
from fastNLP import _logger
|
||||||
# 定义在命令行中的显示格式和日志等级
|
#
|
||||||
_logger.set_stdout('tqdm', level='WARN')
|
# _logger 可以和 logging.Logger 一样使用
|
||||||
|
_logger.info('your msg')
|
||||||
|
_logger.error('your msg')
|
||||||
|
|
||||||
|
# _logger 新增的API
|
||||||
|
# 将日志输出到文件,以及输出的日志等级
|
||||||
|
_logger.add_file('/path/to/log', level='INFO')
|
||||||
|
# 定义在命令行中的显示格式和日志等级
|
||||||
|
_logger.set_stdout('tqdm', level='WARN')
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -10,12 +10,13 @@ def print(*args, sep=' ', end='\n', file=None, flush=False):
|
|||||||
用来重定向 print 函数至 logger.info 的函数。
|
用来重定向 print 函数至 logger.info 的函数。
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
from fastNLP import print
|
from fastNLP import print
|
||||||
print("This is a test") # 等价于调用了 logger.info("This is a test")
|
print("This is a test") # 等价于调用了 logger.info("This is a test")
|
||||||
|
|
||||||
:param args: 需要打印的内容
|
:param args: 需要打印的内容
|
||||||
:param sep: 存在多个输入时,使用的间隔。
|
:param sep: 存在多个输入时,使用的间隔。
|
||||||
:param end: 该参数在当前设置无意义,因为结尾一定会被加入 \n 。
|
:param end: 该参数在当前设置无意义,因为结尾一定会被加入 '\\\\n' 。
|
||||||
:param file: 该参数无意义。
|
:param file: 该参数无意义。
|
||||||
:param flush: 该参数无意义。
|
:param flush: 该参数无意义。
|
||||||
:return:
|
:return:
|
||||||
|
@ -35,7 +35,9 @@ class NumConsumedSamplesArray:
|
|||||||
def __init__(self, buffer_size=2000, num_consumed_samples=0):
|
def __init__(self, buffer_size=2000, num_consumed_samples=0):
|
||||||
"""
|
"""
|
||||||
保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少
|
保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
array = NumConsumedSamplesArray(buffer_size=3)
|
array = NumConsumedSamplesArray(buffer_size=3)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
array.push(i)
|
array.push(i)
|
||||||
|
@ -222,7 +222,7 @@ def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _chec
|
|||||||
|
|
||||||
可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理。
|
可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理。
|
||||||
如果在函数加上了装饰器@cache_results(),则函数会增加五个参数[_cache_fp, _hash_param, _refresh, _verbose,
|
如果在函数加上了装饰器@cache_results(),则函数会增加五个参数[_cache_fp, _hash_param, _refresh, _verbose,
|
||||||
_check_hash]。上面的例子即为使用_cache_fp的情况,这五个参数不会传入到被装饰函数中,当然被装饰函数参数名也不能包含这五个名称::
|
_check_hash]。上面的例子即为使用_cache_fp的情况,这五个参数不会传入到被装饰函数中,当然被装饰函数参数名也不能包含这五个名称。
|
||||||
|
|
||||||
:param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在
|
:param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在
|
||||||
函数调用的时候传入 _cache_fp 这个参数。保存文件的名称会受到
|
函数调用的时候传入 _cache_fp 这个参数。保存文件的名称会受到
|
||||||
|
@ -257,6 +257,7 @@ def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None,
|
|||||||
对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用;
|
对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用;
|
||||||
|
|
||||||
转换的逻辑按优先级依次为:
|
转换的逻辑按优先级依次为:
|
||||||
|
|
||||||
1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`;
|
1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`;
|
||||||
2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`];
|
2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`];
|
||||||
如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key];
|
如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key];
|
||||||
@ -440,12 +441,16 @@ def _is_iterable(value):
|
|||||||
def pretty_table_printer(dataset_or_ins) -> PrettyTable:
|
def pretty_table_printer(dataset_or_ins) -> PrettyTable:
|
||||||
r"""
|
r"""
|
||||||
:param dataset_or_ins: 传入一个dataSet或者instance
|
:param dataset_or_ins: 传入一个dataSet或者instance
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
|
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
|
||||||
+-----------+-----------+-----------------+
|
+-----------+-----------+-----------------+
|
||||||
| field_1 | field_2 | field_3 |
|
| field_1 | field_2 | field_3 |
|
||||||
+-----------+-----------+-----------------+
|
+-----------+-----------+-----------------+
|
||||||
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
|
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
|
||||||
+-----------+-----------+-----------------+
|
+-----------+-----------+-----------------+
|
||||||
|
|
||||||
:return: 以 pretty table的形式返回根据terminal大小进行自动截断
|
:return: 以 pretty table的形式返回根据terminal大小进行自动截断
|
||||||
"""
|
"""
|
||||||
x = PrettyTable()
|
x = PrettyTable()
|
||||||
|
@ -84,7 +84,8 @@ def all_rank_call_context():
|
|||||||
"""
|
"""
|
||||||
在多卡模式下,该环境内,会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0",使得 rank_zero_call 函数失效,使得每个进程都会运行该函数。
|
在多卡模式下,该环境内,会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0",使得 rank_zero_call 函数失效,使得每个进程都会运行该函数。
|
||||||
|
|
||||||
# 使用方式
|
使用方式::
|
||||||
|
|
||||||
with all_rank_call_context():
|
with all_rank_call_context():
|
||||||
do_something # all rank will do
|
do_something # all rank will do
|
||||||
|
|
||||||
|
@ -233,8 +233,8 @@ class DataBundle:
|
|||||||
如果为False,则报错
|
如果为False,则报错
|
||||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
|
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
|
||||||
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。
|
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。
|
||||||
:param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称
|
:param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称
|
||||||
:param show_progress_bar 是否显示tqdm进度条
|
:param show_progress_bar: 是否显示tqdm进度条
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_progress_desc = progress_desc
|
_progress_desc = progress_desc
|
||||||
|
133
tests/core/controllers/test_trainer_jittor.py
Normal file
133
tests/core/controllers/test_trainer_jittor.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastNLP.core.controllers.trainer import Trainer
|
||||||
|
from fastNLP.core.controllers.trainer import Evaluator
|
||||||
|
from fastNLP.core.metrics.accuracy import Accuracy
|
||||||
|
from fastNLP.core.callbacks.progress_callback import RichCallback
|
||||||
|
from fastNLP.core.dataloaders.jittor_dataloader.fdl import JittorDataLoader
|
||||||
|
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||||
|
|
||||||
|
if _NEED_IMPORT_JITTOR:
|
||||||
|
import jittor as jt
|
||||||
|
from jittor import nn, Module
|
||||||
|
from jittor.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class JittorNormalModel_Classification(Module):
|
||||||
|
"""
|
||||||
|
基础的 Jittor 分类模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_labels, feature_dimension):
|
||||||
|
super(JittorNormalModel_Classification, self).__init__()
|
||||||
|
self.num_labels = num_labels
|
||||||
|
|
||||||
|
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
|
||||||
|
self.ac1 = nn.ReLU()
|
||||||
|
self.linear2 = nn.Linear(in_features=64, out_features=32)
|
||||||
|
self.ac2 = nn.ReLU()
|
||||||
|
self.output = nn.Linear(in_features=32, out_features=num_labels)
|
||||||
|
self.loss_fn = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
def execute(self, x):
|
||||||
|
# It's similar to forward function in Pytorch
|
||||||
|
x = self.ac1(self.linear1(x))
|
||||||
|
x = self.ac2(self.linear2(x))
|
||||||
|
x = self.output(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def train_step(self, x, y):
|
||||||
|
x = self(x)
|
||||||
|
return {"loss": self.loss_fn(x, y)}
|
||||||
|
|
||||||
|
def evaluate_step(self, x, y):
|
||||||
|
x = self(x)
|
||||||
|
return {"pred": x, "target": y.reshape((-1,))}
|
||||||
|
|
||||||
|
|
||||||
|
class JittorRandomMaxDataset(Dataset):
|
||||||
|
def __init__(self, num_samples, num_features):
|
||||||
|
super(JittorRandomMaxDataset, self).__init__()
|
||||||
|
self.x = jt.randn((num_samples, num_features))
|
||||||
|
self.y = self.x.argmax(dim=1)[0]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.y)
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return {"x": self.x[item], "y": self.y[item]}
|
||||||
|
|
||||||
|
|
||||||
|
class TrainJittorConfig:
|
||||||
|
num_labels: int = 5
|
||||||
|
feature_dimension: int = 5
|
||||||
|
lr = 1e-1
|
||||||
|
batch_size: int = 4
|
||||||
|
shuffle: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("driver,device", [("jittor", None)])
|
||||||
|
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
|
||||||
|
def test_trainer_jittor(
|
||||||
|
driver,
|
||||||
|
device,
|
||||||
|
callbacks,
|
||||||
|
n_epochs=3,
|
||||||
|
):
|
||||||
|
model = JittorNormalModel_Classification(
|
||||||
|
num_labels=TrainJittorConfig.num_labels,
|
||||||
|
feature_dimension=TrainJittorConfig.feature_dimension
|
||||||
|
)
|
||||||
|
optimizer = nn.SGD(model.parameters(), lr=TrainJittorConfig.lr)
|
||||||
|
train_dataloader = JittorDataLoader(
|
||||||
|
dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
|
||||||
|
batch_size=TrainJittorConfig.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
# num_workers=4,
|
||||||
|
)
|
||||||
|
val_dataloader = JittorDataLoader(
|
||||||
|
dataset=JittorRandomMaxDataset(500, TrainJittorConfig.feature_dimension),
|
||||||
|
batch_size=TrainJittorConfig.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
# num_workers=4,
|
||||||
|
)
|
||||||
|
test_dataloader = JittorDataLoader(
|
||||||
|
dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
|
||||||
|
batch_size=TrainJittorConfig.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
# num_workers=4,
|
||||||
|
)
|
||||||
|
metrics = {"acc": Accuracy()}
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
driver=driver,
|
||||||
|
device=device,
|
||||||
|
optimizers=optimizer,
|
||||||
|
train_dataloader=train_dataloader,
|
||||||
|
evaluate_dataloaders=val_dataloader,
|
||||||
|
validate_every=-1,
|
||||||
|
evaluate_fn="evaluate_step",
|
||||||
|
input_mapping=None,
|
||||||
|
output_mapping=None,
|
||||||
|
metrics=metrics,
|
||||||
|
n_epochs=n_epochs,
|
||||||
|
callbacks=callbacks,
|
||||||
|
# progress_bar="rich"
|
||||||
|
)
|
||||||
|
trainer.run()
|
||||||
|
|
||||||
|
evaluator = Evaluator(
|
||||||
|
model=model,
|
||||||
|
driver=driver,
|
||||||
|
dataloaders=test_dataloader,
|
||||||
|
evaluate_fn="evaluate_step",
|
||||||
|
metrics=metrics,
|
||||||
|
)
|
||||||
|
metric_results = evaluator.run()
|
||||||
|
assert metric_results["acc#acc"] > 0.80
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# test_trainer_jittor("jittor", None, [RichCallback(100)])
|
||||||
|
pytest.main(['test_trainer_jittor.py']) # 只运行此模块
|
@ -1,7 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import Dataset as HfDataset
|
from datasets import Dataset as HfDataset
|
||||||
from datasets import load_dataset
|
|
||||||
|
|
||||||
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
|
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
|
||||||
from fastNLP.core.dataset import DataSet as Fdataset
|
from fastNLP.core.dataset import DataSet as Fdataset
|
||||||
@ -23,16 +22,12 @@ class MyDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self.data[item]
|
return self.data[item]
|
||||||
# return {'x': [[1, 0], [2, 0, 1]]}
|
|
||||||
# return np.random.randn(3, 10)
|
|
||||||
|
|
||||||
# def __len__(self):
|
|
||||||
# return self.dataset_len
|
|
||||||
|
|
||||||
@pytest.mark.jittor
|
@pytest.mark.jittor
|
||||||
class TestJittor:
|
class TestJittor:
|
||||||
|
|
||||||
def test_v1(self):
|
def test_jittor_dataset(self):
|
||||||
"""
|
"""
|
||||||
测试jittor类型的dataset使用fdl
|
测试jittor类型的dataset使用fdl
|
||||||
|
|
||||||
@ -40,13 +35,13 @@ class TestJittor:
|
|||||||
"""
|
"""
|
||||||
dataset = MyDataset()
|
dataset = MyDataset()
|
||||||
jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4)
|
jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4)
|
||||||
# jtl.set_pad_val('x', 'y')
|
|
||||||
# jtl.set_input('x')
|
|
||||||
for batch in jtl:
|
for batch in jtl:
|
||||||
print(batch)
|
assert batch.size() == [4, 3, 4]
|
||||||
print(jtl.get_batch_indices())
|
jtl1 = JittorDataLoader(dataset, keep_numpy_array=False, batch_size=4, num_workers=2)
|
||||||
|
for batch in jtl1:
|
||||||
|
assert batch.size() == [4, 3, 4]
|
||||||
|
|
||||||
def test_v2(self):
|
def test_fastnlp_Dataset(self):
|
||||||
"""
|
"""
|
||||||
测试fastnlp的dataset
|
测试fastnlp的dataset
|
||||||
|
|
||||||
@ -56,26 +51,27 @@ class TestJittor:
|
|||||||
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True)
|
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True)
|
||||||
jtl.set_pad("x", -1)
|
jtl.set_pad("x", -1)
|
||||||
jtl.set_ignore("y")
|
jtl.set_ignore("y")
|
||||||
# jtl.set_pad_val('x', val=-1)
|
|
||||||
# jtl.set_input('x', 'y')
|
|
||||||
for batch in jtl:
|
for batch in jtl:
|
||||||
assert batch['x'].size() == (16, 4)
|
assert batch['x'].size() == (16, 4)
|
||||||
|
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True, num_workers=2)
|
||||||
|
|
||||||
def test_v3(self):
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_huggingface_datasets(self):
|
||||||
dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100})
|
dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100})
|
||||||
jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True)
|
jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True)
|
||||||
# jtl.set_input('x', 'y')
|
|
||||||
for batch in jtl:
|
for batch in jtl:
|
||||||
print(batch)
|
assert batch['x'].size() == [4, 4]
|
||||||
|
assert len(batch['y']) == 4
|
||||||
|
|
||||||
def test_v4(self):
|
def test_num_workers(self):
|
||||||
dataset = MyDataset()
|
dataset = MyDataset()
|
||||||
dl = JittorDataLoader(dataset, batch_size=4, num_workers=2)
|
dl = JittorDataLoader(dataset, batch_size=4, num_workers=2)
|
||||||
print(len(dl))
|
|
||||||
for idx, batch in enumerate(dl):
|
for idx, batch in enumerate(dl):
|
||||||
print(batch.shape, idx)
|
assert batch.shape == [4, 3, 4]
|
||||||
for idx, batch in enumerate(dl):
|
for idx, batch in enumerate(dl):
|
||||||
print(batch.shape, idx)
|
assert batch.shape == [4, 3, 4]
|
||||||
|
|
||||||
def test_v5(self):
|
def test_v5(self):
|
||||||
dataset = MyDataset()
|
dataset = MyDataset()
|
||||||
|
@ -6,19 +6,19 @@ from fastNLP.core.dataset import DataSet
|
|||||||
from fastNLP.core.log import logger
|
from fastNLP.core.log import logger
|
||||||
|
|
||||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
|
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
|
||||||
|
|
||||||
if _NEED_IMPORT_PADDLE:
|
if _NEED_IMPORT_PADDLE:
|
||||||
from paddle.io import Dataset, DataLoader
|
from paddle.io import Dataset
|
||||||
import paddle
|
import paddle
|
||||||
else:
|
else:
|
||||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
|
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RandomDataset(Dataset):
|
class RandomDataset(Dataset):
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
image = np.random.random((10, 5)).astype('float32')
|
image = np.random.random((10, 5)).astype('float32')
|
||||||
return {'image': image, 'label': [[0, 1], [1, 2, 3, 4]]}
|
return {'image': paddle.to_tensor(image), 'label': [[0, 1], [1, 2, 3, 4]]}
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return 10
|
return 10
|
||||||
@ -33,16 +33,22 @@ class TestPaddle:
|
|||||||
fdl = PaddleDataLoader(ds, batch_size=2)
|
fdl = PaddleDataLoader(ds, batch_size=2)
|
||||||
# fdl = DataLoader(ds, batch_size=2, shuffle=True)
|
# fdl = DataLoader(ds, batch_size=2, shuffle=True)
|
||||||
for batch in fdl:
|
for batch in fdl:
|
||||||
print(batch)
|
assert batch['image'].shape == [2, 10, 5]
|
||||||
|
assert batch['label'].shape == [2, 2, 4]
|
||||||
# print(fdl.get_batch_indices())
|
# print(fdl.get_batch_indices())
|
||||||
|
|
||||||
def test_fdl_batch_indices(self):
|
def test_fdl_fastnlp_dataset(self):
|
||||||
ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})
|
ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})
|
||||||
fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
|
fdl = PaddleDataLoader(ds, batch_size=3, shuffle=False, drop_last=True)
|
||||||
|
fdl.set_ignore('y')
|
||||||
|
fdl.set_pad('x', -1)
|
||||||
for batch in fdl:
|
for batch in fdl:
|
||||||
assert len(fdl.get_batch_indices()) == 4
|
assert len(fdl.get_batch_indices()) == 3
|
||||||
print(batch)
|
assert 'y' not in batch
|
||||||
print(fdl.get_batch_indices())
|
assert batch['x'].shape == [3, 3]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
PaddleDataLoader(ds, batch_size=3, collate_fn=None)
|
||||||
|
|
||||||
def test_set_inputs_and_set_pad_val(self):
|
def test_set_inputs_and_set_pad_val(self):
|
||||||
logger.setLevel("DEBUG")
|
logger.setLevel("DEBUG")
|
||||||
@ -50,11 +56,8 @@ class TestPaddle:
|
|||||||
fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True)
|
fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True)
|
||||||
fdl.set_pad('label', -1)
|
fdl.set_pad('label', -1)
|
||||||
for batch in fdl:
|
for batch in fdl:
|
||||||
print(batch['image'])
|
|
||||||
assert batch['image'].shape == [2, 10, 5]
|
assert batch['image'].shape == [2, 10, 5]
|
||||||
print(batch)
|
|
||||||
fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True)
|
fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True)
|
||||||
fdl1.set_ignore('label')
|
fdl1.set_ignore('label')
|
||||||
for batch in fdl1:
|
for batch in fdl1:
|
||||||
assert batch['image'].shape == [4, 10, 5]
|
assert batch['image'].shape == [4, 10, 5]
|
||||||
print(batch)
|
|
||||||
|
@ -4,6 +4,7 @@ from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_t
|
|||||||
from fastNLP.core.dataset import DataSet
|
from fastNLP.core.dataset import DataSet
|
||||||
from fastNLP.io.data_bundle import DataBundle
|
from fastNLP.io.data_bundle import DataBundle
|
||||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||||
|
from fastNLP.core import Trainer
|
||||||
|
|
||||||
if _NEED_IMPORT_TORCH:
|
if _NEED_IMPORT_TORCH:
|
||||||
import torch
|
import torch
|
||||||
|
Loading…
Reference in New Issue
Block a user