mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 03:07:59 +08:00
修正部分文档的格式问题
This commit is contained in:
parent
79d718d23f
commit
734affed76
@ -146,11 +146,13 @@ class CallbackManager:
|
||||
r"""
|
||||
用于断点重训的 callback 的保存函数;
|
||||
该函数主要涉及两个方面:
|
||||
1. callback 的状态的保存;我们会调用每一个 callback 的 `on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着
|
||||
断点重训应当保存的状态;
|
||||
2. 每一个具体的 callback 函数的 filter 的状态;
|
||||
|
||||
:return: 一个包含上述内容的字典::
|
||||
1. callback 的状态的保存;我们会调用每一个 callback 的 `on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着
|
||||
断点重训应当保存的状态;
|
||||
2. 每一个具体的 callback 函数的 filter 的状态;
|
||||
|
||||
:return: 一个包含上述内容的字典:
|
||||
.. code-block::
|
||||
|
||||
{
|
||||
"callback_name_1": {
|
||||
@ -158,6 +160,7 @@ class CallbackManager:
|
||||
"filter_states": {"on_train_begin": filter1.state_dict(), ...}
|
||||
}
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
states = {}
|
||||
|
@ -39,7 +39,7 @@ class MoreEvaluateCallback(HasMonitorCallback):
|
||||
意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种
|
||||
取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最
|
||||
匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor
|
||||
的结果,如果当前结果中没有相关的monitor 值请返回 None 。
|
||||
的结果,如果当前结果中没有相关的monitor 值请返回 None 。
|
||||
:param watch_monitor_larger_better: watch_monitor 是否越大越好。
|
||||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是
|
||||
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有
|
||||
|
@ -10,13 +10,13 @@ class TorchGradClipCallback(Callback):
|
||||
在每次 optimizer update 之前将 parameter 进行 clip
|
||||
|
||||
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数
|
||||
:param str clip_type: 支持'norm', 'value'两种::
|
||||
:param str clip_type: 支持'norm', 'value'两种:
|
||||
|
||||
1 'norm', 将gradient的norm rescale到[-clip_value, clip_value]
|
||||
1. 'norm', 将gradient的norm rescale到[-clip_value, clip_value]
|
||||
2. 'value', 将gradient限制在[-clip_value, clip_value],
|
||||
小于-clip_value的gradient被赋值为-clip_value;
|
||||
大于clip_value的gradient被赋值为clip_value.
|
||||
|
||||
2 'value', 将gradient限制在[-clip_value, clip_value],
|
||||
小于-clip_value的gradient被赋值为-clip_value;
|
||||
大于clip_value的gradient被赋值为clip_value.
|
||||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。
|
||||
如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。
|
||||
"""
|
||||
|
@ -51,23 +51,30 @@ class Evaluator:
|
||||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`;
|
||||
:param fp16: 是否使用 fp16 。
|
||||
:param verbose: 是否打印 evaluate 的结果。
|
||||
:param kwargs:
|
||||
bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout
|
||||
与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论
|
||||
该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。
|
||||
TODO 还没完成。
|
||||
Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的
|
||||
tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象,
|
||||
当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数
|
||||
不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的输出 tensor 类型通过 driver 来决定,
|
||||
metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换,请使用 input_mapping、output_mapping 参数进行。
|
||||
use_dist_sampler: 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持
|
||||
分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。
|
||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
|
||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
|
||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
|
||||
progress_bar: evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测
|
||||
到当前terminal为交互型则使用 rich,否则使用 raw。
|
||||
:param \**kwargs:
|
||||
See below
|
||||
:kwargs:
|
||||
* *model_use_eval_mode* (``bool``) --
|
||||
是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的
|
||||
dropout 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论
|
||||
该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。
|
||||
TODO 还没完成。
|
||||
* *auto_tensor_conversion_for_metric* (``Union[bool]``) --
|
||||
是否自动将输出中的 tensor 适配到 metrics 支持的。例如 model 输出是
|
||||
paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象,当 auto_tensor_conversion_for_metric 为True时,fastNLP 将
|
||||
自动将输出中 paddle 的 tensor (其它非 tensor 的参数不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的
|
||||
输出 tensor 类型通过 driver 来决定,metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换,
|
||||
请使用 input_mapping、output_mapping 参数进行。
|
||||
* *use_dist_sampler* --
|
||||
是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持
|
||||
分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。
|
||||
* *output_from_new_proc* --
|
||||
应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
|
||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
|
||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
|
||||
* *progress_bar* --
|
||||
evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测
|
||||
到当前terminal为交互型则使用 rich,否则使用 raw。
|
||||
"""
|
||||
|
||||
self.model = model
|
||||
|
@ -67,20 +67,21 @@ class Trainer(TrainerEventTrigger):
|
||||
要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP;
|
||||
|
||||
:param model: 训练所需要的模型,目前支持 pytorch;
|
||||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle
|
||||
等国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练
|
||||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle 等
|
||||
国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练
|
||||
:param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict;
|
||||
:param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List;
|
||||
:param device: 该参数用来指定具体训练时使用的机器;注意当该参数为 None 时,fastNLP 不会将模型和数据进行设备之间的移动处理,但是你
|
||||
可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也
|
||||
可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前
|
||||
自己构造 DDP 的多进程场景);
|
||||
可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也
|
||||
可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前
|
||||
自己构造 DDP 的多进程场景);
|
||||
device 的可选输入如下所示:
|
||||
1. 可选输入:str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, 可见的第二个GPU中;
|
||||
2. torch.device:将模型装载到torch.device上;
|
||||
3. int: 将使用device_id为该值的gpu进行训练;如果值为 -1,那么默认使用全部的显卡,此时是 `TorchDDPDriver`;
|
||||
4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`;
|
||||
5. None: 为None则不对模型进行任何处理;
|
||||
|
||||
:param n_epochs: 训练总共的 epoch 的数量,默认为 20;
|
||||
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认
|
||||
为 None;
|
||||
@ -121,26 +122,27 @@ class Trainer(TrainerEventTrigger):
|
||||
如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。
|
||||
:param larger_better: monitor 的值是否是越大越好。
|
||||
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None;
|
||||
:param kwargs: 一些其它的可能需要的参数;
|
||||
torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
|
||||
data_device: 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上;
|
||||
注意如果 model_device 为 None,那么 data_device 不会起作用;
|
||||
torch_ddp_kwargs: 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数;仅用于 pytorch ddp 训练。例如传入
|
||||
{'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等。
|
||||
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
|
||||
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch
|
||||
:param kwargs: 一些其它的可能需要的参数,见下方的说明
|
||||
:kwargs:
|
||||
* *torch_non_blocking* -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
|
||||
* *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上;
|
||||
注意如果 model_device 为 None,那么 data_device 不会起作用;
|
||||
* *torch_ddp_kwargs* -- 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数;仅用于 pytorch ddp 训练。例如传入
|
||||
{'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等。
|
||||
* *set_grad_to_none* -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
|
||||
* *use_dist_sampler* -- 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch
|
||||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
|
||||
evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
|
||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
|
||||
* *evaluate_use_dist_sampler* -- 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
|
||||
* *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
|
||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
|
||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
|
||||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
|
||||
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
|
||||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果
|
||||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
|
||||
train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。
|
||||
train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。
|
||||
evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。
|
||||
evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。
|
||||
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。
|
||||
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。
|
||||
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。
|
||||
* *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。
|
||||
"""
|
||||
self.model = model
|
||||
self.marker = marker
|
||||
@ -290,14 +292,14 @@ class Trainer(TrainerEventTrigger):
|
||||
catch_KeyboardInterrupt=None):
|
||||
"""
|
||||
注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint
|
||||
去保存断点重训的文件;
|
||||
去保存断点重训的文件;
|
||||
:param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止,-1 为根据 dataloader 有多少个 batch 决定。
|
||||
:param num_eval_batch_per_dl: 每个 evaluate dataloader 运行多少个 batch 停止,-1 为根据 dataloader 有多少个 batch 决定。
|
||||
:param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 是否有错误。为 0 表示不检测。
|
||||
:param resume_from: 从哪个路径下恢复 trainer 的状态
|
||||
:param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态。
|
||||
:param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话,不会抛出一场,trainer.run()之后的代码会继续运
|
||||
行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch )
|
||||
行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch )
|
||||
:return:
|
||||
"""
|
||||
|
||||
@ -417,39 +419,42 @@ class Trainer(TrainerEventTrigger):
|
||||
def on(cls, event: Event, marker: Optional[str] = None):
|
||||
r"""
|
||||
函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制;
|
||||
支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如
|
||||
Trainer.__init__():
|
||||
on_after_trainer_initialized(trainer, driver)
|
||||
Trainer.run():
|
||||
if num_eval_sanity_batch>0:
|
||||
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch
|
||||
on_sanity_check_end(trainer, sanity_check_res)
|
||||
try:
|
||||
on_train_begin(trainer)
|
||||
while cur_epoch_idx < n_epochs:
|
||||
on_train_epoch_begin(trainer)
|
||||
while batch_idx_in_epoch<=num_batches_per_epoch:
|
||||
on_fetch_data_begin(trainer)
|
||||
batch = next(dataloader)
|
||||
on_fetch_data_end(trainer)
|
||||
on_train_batch_begin(trainer, batch, indices)
|
||||
on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。
|
||||
on_after_backward(trainer)
|
||||
on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_train_batch_end(trainer)
|
||||
on_train_epoch_end(trainer)
|
||||
except BaseException:
|
||||
self.on_exception(trainer, exception)
|
||||
finally:
|
||||
on_train_end(trainer)
|
||||
支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如::
|
||||
|
||||
Trainer.__init__():
|
||||
on_after_trainer_initialized(trainer, driver)
|
||||
Trainer.run():
|
||||
if num_eval_sanity_batch>0:
|
||||
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch
|
||||
on_sanity_check_end(trainer, sanity_check_res)
|
||||
try:
|
||||
on_train_begin(trainer)
|
||||
while cur_epoch_idx < n_epochs:
|
||||
on_train_epoch_begin(trainer)
|
||||
while batch_idx_in_epoch<=num_batches_per_epoch:
|
||||
on_fetch_data_begin(trainer)
|
||||
batch = next(dataloader)
|
||||
on_fetch_data_end(trainer)
|
||||
on_train_batch_begin(trainer, batch, indices)
|
||||
on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。
|
||||
on_after_backward(trainer)
|
||||
on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_train_batch_end(trainer)
|
||||
on_train_epoch_end(trainer)
|
||||
except BaseException:
|
||||
self.on_exception(trainer, exception)
|
||||
finally:
|
||||
on_train_end(trainer)
|
||||
|
||||
其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
|
||||
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中
|
||||
特定的时间调用。
|
||||
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中
|
||||
特定的时间调用。
|
||||
|
||||
Example::
|
||||
|
||||
from fastNLP import Event
|
||||
@Trainer.on(Event.on_save_model())
|
||||
def do_something_1(trainer):
|
||||
@ -696,7 +701,7 @@ class Trainer(TrainerEventTrigger):
|
||||
r"""
|
||||
用于断点重训的加载函数;
|
||||
注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的
|
||||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler;
|
||||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler;
|
||||
|
||||
注意我们目前不支持单卡到多卡的断点重训;
|
||||
|
||||
|
@ -26,7 +26,8 @@ class State(dict):
|
||||
|
||||
为了实现断点重训,用户应当保证其保存的信息都是可序列化的;
|
||||
|
||||
推荐的使用方式:
|
||||
推荐的使用方式::
|
||||
|
||||
>>> state = State()
|
||||
>>> state["best_accuracy"] = 0.9
|
||||
>>> print(state["best_accuracy"])
|
||||
|
@ -137,6 +137,7 @@ class JittorDataLoader:
|
||||
"""
|
||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||
Example::
|
||||
|
||||
collator.set_ignore('field1', 'field2')
|
||||
|
||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
|
@ -142,6 +142,7 @@ class PaddleDataLoader(DataLoader):
|
||||
"""
|
||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||
Example::
|
||||
|
||||
collator.set_ignore('field1', 'field2')
|
||||
|
||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
|
@ -153,6 +153,7 @@ class TorchDataLoader(DataLoader):
|
||||
"""
|
||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||
Example::
|
||||
|
||||
collator.set_ignore('field1', 'field2')
|
||||
|
||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
|
@ -706,8 +706,8 @@ class DataSet:
|
||||
def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet':
|
||||
"""
|
||||
将当前dataset与输入的dataset结合成一个更大的dataset,需要保证两个dataset都包含了相同的field。结合后的dataset的input,target
|
||||
以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有
|
||||
当前dataset含有field,则会报错。
|
||||
以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有
|
||||
当前dataset含有field,则会报错。
|
||||
|
||||
:param DataSet, dataset: 需要和当前dataset concat的dataset
|
||||
:param bool, inplace: 是否直接将dataset组合到当前dataset中
|
||||
|
@ -87,8 +87,8 @@ class Driver(ABC):
|
||||
|
||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型;
|
||||
:param fn: 调用该函数进行一次计算。
|
||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call
|
||||
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward;
|
||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call 函
|
||||
数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward;
|
||||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查);
|
||||
"""
|
||||
raise NotImplementedError("Each specific driver should implemented its own `model_call` function.")
|
||||
@ -106,9 +106,10 @@ class Driver(ABC):
|
||||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中;
|
||||
|
||||
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示:
|
||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward`
|
||||
函数,然后给出 warning;
|
||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错;
|
||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward`
|
||||
函数,然后给出 warning;
|
||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错;
|
||||
|
||||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的
|
||||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此
|
||||
可能需要额外标记最初传入 driver 的模型是哪种形式的;
|
||||
@ -376,7 +377,7 @@ class Driver(ABC):
|
||||
的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉;
|
||||
|
||||
因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的
|
||||
pid 的信息;
|
||||
pid 的信息;
|
||||
"""
|
||||
# 单卡 driver 不需要这个函数;
|
||||
if self._pids is not None:
|
||||
|
@ -172,6 +172,7 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=None) ->List:
|
||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。
|
||||
|
||||
example::
|
||||
|
||||
obj = {
|
||||
'a': [1, 1],
|
||||
'b': [[1, 2], [1, 2]],
|
||||
|
@ -534,7 +534,7 @@ class TorchDDPDriver(TorchDriver):
|
||||
def broadcast_object(self, obj, src:int=0, group=None, **kwargs):
|
||||
"""
|
||||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行
|
||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。
|
||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。
|
||||
|
||||
:param obj: obj,可能是 Tensor 或 嵌套类型的数据
|
||||
:param int src: source 的 global rank 。
|
||||
@ -551,9 +551,10 @@ class TorchDDPDriver(TorchDriver):
|
||||
def all_gather(self, obj, group) -> List:
|
||||
"""
|
||||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过
|
||||
pickle 进行序列化,接收到之后再反序列化。
|
||||
pickle 进行序列化,接收到之后再反序列化。
|
||||
|
||||
example::
|
||||
|
||||
example:
|
||||
obj = {
|
||||
'a': [1, 1],
|
||||
'b': [[1, 2], [1, 2]],
|
||||
|
@ -175,7 +175,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) -
|
||||
"""
|
||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。
|
||||
|
||||
example:
|
||||
example::
|
||||
|
||||
obj = {
|
||||
'a': [1, 1],
|
||||
'b': [[1, 2], [1, 2]],
|
||||
|
@ -175,16 +175,18 @@ def _build_fp16_env(dummy=False):
|
||||
|
||||
def replace_sampler(dataloader: "DataLoader", sampler):
|
||||
"""
|
||||
替换 sampler (初始化一个新的 dataloader 的逻辑在于):
|
||||
替换 sampler (初始化一个新的 dataloader 的逻辑在于):
|
||||
|
||||
用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接
|
||||
`inspect.signature(DataLoader)` 的原因,因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader
|
||||
的类,而不是直接的 DataLoader;
|
||||
用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接
|
||||
`inspect.signature(DataLoader)` 的原因,因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader
|
||||
的类,而不是直接的 DataLoader;
|
||||
|
||||
如果需要定制自己的 dataloader,保证以下两点:
|
||||
|
||||
1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中;
|
||||
2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性
|
||||
来获取实际的参数的值;
|
||||
|
||||
如果需要定制自己的 dataloader,保证以下两点:
|
||||
1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中;
|
||||
2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性
|
||||
来获取实际的参数的值;
|
||||
"""
|
||||
|
||||
# 拿到实例属性;
|
||||
|
@ -5,7 +5,7 @@ from pathlib import Path
|
||||
|
||||
from fastNLP.core.drivers.driver import Driver
|
||||
|
||||
|
||||
__all__ = []
|
||||
|
||||
def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
|
||||
r"""
|
||||
|
@ -1,18 +1,20 @@
|
||||
r"""
|
||||
Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger,
|
||||
具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API
|
||||
使用方式:
|
||||
from fastNLP import _logger
|
||||
#
|
||||
# _logger 可以和 logging.Logger 一样使用
|
||||
_logger.info('your msg')
|
||||
_logger.error('your msg')
|
||||
|
||||
# _logger 新增的API
|
||||
# 将日志输出到文件,以及输出的日志等级
|
||||
_logger.add_file('/path/to/log', level='INFO')
|
||||
# 定义在命令行中的显示格式和日志等级
|
||||
_logger.set_stdout('tqdm', level='WARN')
|
||||
使用方式::
|
||||
|
||||
from fastNLP import _logger
|
||||
#
|
||||
# _logger 可以和 logging.Logger 一样使用
|
||||
_logger.info('your msg')
|
||||
_logger.error('your msg')
|
||||
|
||||
# _logger 新增的API
|
||||
# 将日志输出到文件,以及输出的日志等级
|
||||
_logger.add_file('/path/to/log', level='INFO')
|
||||
# 定义在命令行中的显示格式和日志等级
|
||||
_logger.set_stdout('tqdm', level='WARN')
|
||||
|
||||
"""
|
||||
|
||||
|
@ -10,12 +10,13 @@ def print(*args, sep=' ', end='\n', file=None, flush=False):
|
||||
用来重定向 print 函数至 logger.info 的函数。
|
||||
|
||||
Example::
|
||||
|
||||
from fastNLP import print
|
||||
print("This is a test") # 等价于调用了 logger.info("This is a test")
|
||||
|
||||
:param args: 需要打印的内容
|
||||
:param sep: 存在多个输入时,使用的间隔。
|
||||
:param end: 该参数在当前设置无意义,因为结尾一定会被加入 \n 。
|
||||
:param end: 该参数在当前设置无意义,因为结尾一定会被加入 '\\\\n' 。
|
||||
:param file: 该参数无意义。
|
||||
:param flush: 该参数无意义。
|
||||
:return:
|
||||
|
@ -38,7 +38,7 @@ class Metric:
|
||||
def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element:
|
||||
"""
|
||||
注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的
|
||||
tensor 直接进行加减乘除计算即可。
|
||||
tensor 直接进行加减乘除计算即可。
|
||||
注意:如果想使得该 metric 可自动扩展到多卡的情况,请一定申明 aggregate_method 。
|
||||
|
||||
:param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。
|
||||
@ -48,7 +48,7 @@ class Metric:
|
||||
Torch.tensor ; 如果backend 为 paddle 则该对象为 paddle.tensor ;如果 backend 为 jittor , 则该对象为 jittor.Var 。
|
||||
一般情况下直接默认为 auto 就行了,fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化,例如当传入
|
||||
的参数中只包含 torch.Tensor 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 torch ;只包含
|
||||
jittor.Var 则认为 backend 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 jittor 。如果没有检测
|
||||
jittor.Var 则认为 backend 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 jittor 。如果没有检测
|
||||
到任何一种 tensor ,就默认使用 float 类型作为 element 。
|
||||
:return: 注册的 Element 对象
|
||||
"""
|
||||
|
@ -496,7 +496,7 @@ class PollingSampler(MixSampler):
|
||||
:param sampler: 实例化好的sampler,每个dataset对应一个sampler对象
|
||||
:param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size
|
||||
:param ds_ratio: 当ds_ratio=None时候, 轮流采样dataset列表直至所有的数据集采样完;当ds_ratio='truncate_to_least'时,
|
||||
以dataset列表最短的ds为基准,长的数据集会被截断;当ds_ratio='pad_to_most'时,以dataset列表最长ds为基准,短的数据集会被重采样
|
||||
以dataset列表最短的ds为基准,长的数据集会被截断;当ds_ratio='pad_to_most'时,以dataset列表最长ds为基准,短的数据集会被重采样
|
||||
"""
|
||||
super(PollingSampler, self).__init__(dataset=dataset, batch_size=batch_size,
|
||||
sampler=sampler, ds_ratio=ds_ratio,
|
||||
|
@ -35,7 +35,9 @@ class NumConsumedSamplesArray:
|
||||
def __init__(self, buffer_size=2000, num_consumed_samples=0):
|
||||
"""
|
||||
保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少
|
||||
|
||||
Example::
|
||||
|
||||
array = NumConsumedSamplesArray(buffer_size=3)
|
||||
for i in range(10):
|
||||
array.push(i)
|
||||
|
@ -222,7 +222,7 @@ def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _chec
|
||||
|
||||
可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理。
|
||||
如果在函数加上了装饰器@cache_results(),则函数会增加五个参数[_cache_fp, _hash_param, _refresh, _verbose,
|
||||
_check_hash]。上面的例子即为使用_cache_fp的情况,这五个参数不会传入到被装饰函数中,当然被装饰函数参数名也不能包含这五个名称::
|
||||
_check_hash]。上面的例子即为使用_cache_fp的情况,这五个参数不会传入到被装饰函数中,当然被装饰函数参数名也不能包含这五个名称。
|
||||
|
||||
:param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在
|
||||
函数调用的时候传入 _cache_fp 这个参数。保存文件的名称会受到
|
||||
|
@ -256,12 +256,13 @@ def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None,
|
||||
对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用;
|
||||
|
||||
转换的逻辑按优先级依次为:
|
||||
1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`;
|
||||
2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`];
|
||||
如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key];
|
||||
如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换;
|
||||
如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用
|
||||
mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。
|
||||
|
||||
1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`;
|
||||
2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`];
|
||||
如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key];
|
||||
如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换;
|
||||
如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用
|
||||
mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。
|
||||
|
||||
:param mapping: 用于转换的字典或者函数;mapping是函数时,返回值必须为字典类型。
|
||||
:param data: 需要被转换的对象;
|
||||
@ -439,12 +440,16 @@ def _is_iterable(value):
|
||||
def pretty_table_printer(dataset_or_ins) -> PrettyTable:
|
||||
r"""
|
||||
:param dataset_or_ins: 传入一个dataSet或者instance
|
||||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
|
||||
+-----------+-----------+-----------------+
|
||||
| field_1 | field_2 | field_3 |
|
||||
+-----------+-----------+-----------------+
|
||||
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
|
||||
+-----------+-----------+-----------------+
|
||||
|
||||
.. code-block::
|
||||
|
||||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
|
||||
+-----------+-----------+-----------------+
|
||||
| field_1 | field_2 | field_3 |
|
||||
+-----------+-----------+-----------------+
|
||||
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
|
||||
+-----------+-----------+-----------------+
|
||||
|
||||
:return: 以 pretty table的形式返回根据terminal大小进行自动截断
|
||||
"""
|
||||
x = PrettyTable()
|
||||
|
@ -47,7 +47,7 @@ def rank_zero_call(fn: Callable):
|
||||
rank_zero_call(add)(1, 2)
|
||||
|
||||
同时,该函数还会设置 FASTNLP_NO_SYNC 为 2,在这个环境下,所有的 fastNLP 内置的 barrier 接口,gather/broadcast 操作都没有任何
|
||||
意义。
|
||||
意义。
|
||||
|
||||
:param fn: 需要包裹的可执行的函数。
|
||||
:return:
|
||||
@ -65,7 +65,7 @@ def rank_zero_call(fn: Callable):
|
||||
def fastnlp_no_sync_context(level=2):
|
||||
"""
|
||||
用于让 fastNLP 的 barrier 以及 gather/broadcast等操作等同于只有1卡的多卡程序。如果为 1 表示 fastNLP 里的barrier 操作失效;
|
||||
如果为 2 表示 barrier 与 gather/broadcast 都失效。
|
||||
如果为 2 表示 barrier 与 gather/broadcast 都失效。
|
||||
|
||||
:param int level: 可选 [0, 1, 2]
|
||||
:return:
|
||||
@ -84,9 +84,10 @@ def all_rank_call_context():
|
||||
"""
|
||||
在多卡模式下,该环境内,会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0",使得 rank_zero_call 函数失效,使得每个进程都会运行该函数。
|
||||
|
||||
# 使用方式
|
||||
with all_rank_call_context():
|
||||
do_something # all rank will do
|
||||
使用方式::
|
||||
|
||||
with all_rank_call_context():
|
||||
do_something # all rank will do
|
||||
|
||||
:param fn:
|
||||
:return:
|
||||
|
@ -233,8 +233,8 @@ class DataBundle:
|
||||
如果为False,则报错
|
||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
|
||||
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。
|
||||
:param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称
|
||||
:param show_progress_bar 是否显示tqdm进度条
|
||||
:param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称
|
||||
:param show_progress_bar: 是否显示tqdm进度条
|
||||
|
||||
"""
|
||||
_progress_desc = progress_desc
|
||||
|
Loading…
Reference in New Issue
Block a user