mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 12:17:35 +08:00
调整部分文档缩进不正确的问题
This commit is contained in:
parent
19708cc89a
commit
4b2263739f
@ -11,37 +11,39 @@ from .callback_event import Event, Filter
|
||||
class Callback:
|
||||
r"""
|
||||
实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类;
|
||||
callback 调用时机顺序大概如下
|
||||
Trainer.__init__():
|
||||
on_after_trainer_initialized(trainer, driver)
|
||||
Trainer.run():
|
||||
if num_eval_sanity_batch>0:
|
||||
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch
|
||||
on_sanity_check_end(trainer, sanity_check_res)
|
||||
try:
|
||||
on_train_begin(trainer)
|
||||
while cur_epoch_idx < n_epochs:
|
||||
on_train_epoch_begin(trainer)
|
||||
while batch_idx_in_epoch<=num_batches_per_epoch:
|
||||
on_fetch_data_begin(trainer)
|
||||
batch = next(dataloader)
|
||||
on_fetch_data_end(trainer)
|
||||
on_train_batch_begin(trainer, batch, indices)
|
||||
on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。
|
||||
on_after_backward(trainer)
|
||||
on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_train_batch_end(trainer)
|
||||
on_train_epoch_end(trainer)
|
||||
except BaseException:
|
||||
self.on_exception(trainer, exception)
|
||||
finally:
|
||||
on_train_end(trainer)
|
||||
callback 调用时机顺序大概如下::
|
||||
|
||||
Trainer.__init__():
|
||||
on_after_trainer_initialized(trainer, driver)
|
||||
Trainer.run():
|
||||
if num_eval_sanity_batch>0:
|
||||
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch
|
||||
on_sanity_check_end(trainer, sanity_check_res)
|
||||
try:
|
||||
on_train_begin(trainer)
|
||||
while cur_epoch_idx < n_epochs:
|
||||
on_train_epoch_begin(trainer)
|
||||
while batch_idx_in_epoch<=num_batches_per_epoch:
|
||||
on_fetch_data_begin(trainer)
|
||||
batch = next(dataloader)
|
||||
on_fetch_data_end(trainer)
|
||||
on_train_batch_begin(trainer, batch, indices)
|
||||
on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。
|
||||
on_after_backward(trainer)
|
||||
on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
|
||||
on_train_batch_end(trainer)
|
||||
on_train_epoch_end(trainer)
|
||||
except BaseException:
|
||||
self.on_exception(trainer, exception)
|
||||
finally:
|
||||
on_train_end(trainer)
|
||||
|
||||
其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
|
||||
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中特定
|
||||
的时间调用。
|
||||
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中特定
|
||||
的时间调用。
|
||||
"""
|
||||
|
||||
def on_after_trainer_initialized(self, trainer, driver):
|
||||
@ -123,8 +125,8 @@ class Callback:
|
||||
def on_train_batch_begin(self, trainer, batch, indices):
|
||||
r"""
|
||||
在取得数据,执行完 input_mapping (如果 Trainer 传有该参数),并且移动 batch 中的 tensor 到了指定设备。
|
||||
其中 batch 中的数据格式要么是 Dataloader 返回的每个 batch 的格式;要么是 input_mapping 之后的内容。
|
||||
如果 batch 是 dict 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入到 model 的中的 batch 数据。
|
||||
其中 batch 中的数据格式要么是 Dataloader 返回的每个 batch 的格式;要么是 input_mapping 之后的内容。
|
||||
如果 batch 是 dict 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入到 model 的中的 batch 数据。
|
||||
|
||||
:param trainer: `fastNLP.Trainer`
|
||||
:param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。
|
||||
@ -136,8 +138,8 @@ class Callback:
|
||||
def on_train_batch_end(self, trainer):
|
||||
"""
|
||||
完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch与
|
||||
global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会
|
||||
执行。
|
||||
global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会
|
||||
执行。
|
||||
|
||||
:param trainer:
|
||||
:return:
|
||||
@ -184,7 +186,7 @@ class Callback:
|
||||
def on_load_checkpoint(self, trainer, states: Optional[Dict]):
|
||||
r"""
|
||||
当 Trainer 要恢复 checkpoint 的时候触发( Trainer 与 Driver 已经加载好自身的状态),参数 states 为 on_save_checkpoint()
|
||||
的返回值。
|
||||
的返回值。
|
||||
|
||||
:param trainer:
|
||||
:param states:
|
||||
@ -205,7 +207,7 @@ class Callback:
|
||||
def on_after_backward(self, trainer):
|
||||
"""
|
||||
在 backward 后执行。在多卡场景下,由于 accumulation_steps 的影响,仅在需要真正 update 参数那次梯度回传才会触发梯度同步,
|
||||
因此在多卡且使用 accumulation_steps 时,可能存在某些 step 各卡上梯度不一致的问题。
|
||||
因此在多卡且使用 accumulation_steps 时,可能存在某些 step 各卡上梯度不一致的问题。
|
||||
|
||||
:param trainer:
|
||||
:return:
|
||||
@ -255,7 +257,7 @@ class Callback:
|
||||
def on_evaluate_begin(self, trainer):
|
||||
"""
|
||||
在将要进行 evaluate 时调用。如果是设置的以 step 数量 或 自定义地 决定 evaluate 的频率,该接口是在 on_train_batch_end 之后
|
||||
进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。
|
||||
进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。
|
||||
|
||||
:param trainer:
|
||||
:return:
|
||||
@ -294,7 +296,7 @@ class Callback:
|
||||
class _CallbackWrapper(Callback):
|
||||
"""
|
||||
对于用户使用函数修饰器加入的 callback 函数,使用该 _CallbackWrapper 类为其进行定制,这一个类只保留用户的
|
||||
这一个 callback 函数;
|
||||
这一个 callback 函数;
|
||||
"""
|
||||
def __init__(self, event: Event, fn: Callable):
|
||||
r"""
|
||||
|
@ -42,7 +42,7 @@ class Event:
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
|
||||
filter.num_executed 两个变量分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
|
||||
"""
|
||||
self.every = every
|
||||
self.once = once
|
||||
@ -59,6 +59,7 @@ class Event:
|
||||
当 Trainer 运行到 on_after_trainer_initialized 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。默认为
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -74,6 +75,7 @@ class Event:
|
||||
当 Trainer 运行到 on_sanity_check_begin 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -89,6 +91,7 @@ class Event:
|
||||
当 Trainer 运行到 on_sanity_check_end 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -104,6 +107,7 @@ class Event:
|
||||
当 Trainer 运行到 on_train_begin 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -119,6 +123,7 @@ class Event:
|
||||
当 Trainer 运行到 on_train_end 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -134,6 +139,7 @@ class Event:
|
||||
当 Trainer 运行到 on_train_epoch_begin 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -149,6 +155,7 @@ class Event:
|
||||
当 Trainer 运行到 on_train_epoch_end 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -164,6 +171,7 @@ class Event:
|
||||
当 Trainer 运行到 on_fetch_data_begin 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -179,6 +187,7 @@ class Event:
|
||||
当 Trainer 运行到 on_fetch_data_end 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -194,6 +203,7 @@ class Event:
|
||||
当 Trainer 运行到 on_train_batch_begin 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -209,6 +219,7 @@ class Event:
|
||||
当 Trainer 运行到 on_train_batch_end 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -224,6 +235,7 @@ class Event:
|
||||
当 Trainer 运行到 on_exception 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -239,6 +251,7 @@ class Event:
|
||||
当 Trainer 运行到 on_save_model 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -254,6 +267,7 @@ class Event:
|
||||
当 Trainer 运行到 on_load_model 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -269,6 +283,7 @@ class Event:
|
||||
当 Trainer 运行到 on_save_checkpoint 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -284,6 +299,7 @@ class Event:
|
||||
当 Trainer 运行到 on_load_checkpoint 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -299,6 +315,7 @@ class Event:
|
||||
当 Trainer 运行到 on_load_checkpoint 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -314,6 +331,7 @@ class Event:
|
||||
当 Trainer 运行到 on_before_backward 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -329,6 +347,7 @@ class Event:
|
||||
当 Trainer 运行到 on_after_backward 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -344,6 +363,7 @@ class Event:
|
||||
当 Trainer 运行到 on_before_optimizers_step 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -359,6 +379,7 @@ class Event:
|
||||
当 Trainer 运行到 on_after_optimizers_step 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -374,6 +395,7 @@ class Event:
|
||||
当 Trainer 运行到 on_before_zero_grad 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -389,6 +411,7 @@ class Event:
|
||||
当 Trainer 运行到 on_after_zero_grad 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -404,6 +427,7 @@ class Event:
|
||||
当 Trainer 运行到 on_evaluate_begin 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
@ -419,6 +443,7 @@ class Event:
|
||||
当 Trainer 运行到 on_evaluate_end 时
|
||||
|
||||
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。
|
||||
|
||||
:param int every: 触发了多少次,才真正运行一次。
|
||||
:param bool once: 是否只在第一次运行后就不再执行了。
|
||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
|
||||
|
@ -110,7 +110,7 @@ class CallbackManager:
|
||||
def initialize_class_callbacks(self):
|
||||
r"""
|
||||
在实际的运行过程中,我们是将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是
|
||||
一个个 callback 时机,也就是 `Event` 的类别;
|
||||
一个个 callback 时机,也就是 `Event` 的类别;
|
||||
如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中;
|
||||
|
||||
:param callbacks:
|
||||
@ -150,7 +150,8 @@ class CallbackManager:
|
||||
断点重训应当保存的状态;
|
||||
2. 每一个具体的 callback 函数的 filter 的状态;
|
||||
|
||||
:return: 一个包含上述内容的字典;
|
||||
:return: 一个包含上述内容的字典::
|
||||
|
||||
{
|
||||
"callback_name_1": {
|
||||
"states": {...},
|
||||
|
@ -19,15 +19,15 @@ class CheckpointCallback(Callback):
|
||||
only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model',
|
||||
save_evaluate_results=True, **kwargs):
|
||||
"""
|
||||
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下
|
||||
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下::
|
||||
|
||||
- folder/
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型
|
||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型
|
||||
- {save_object}-last/ # 最后一个 epoch 的保存
|
||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。
|
||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名
|
||||
- folder/
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型
|
||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型
|
||||
- {save_object}-last/ # 最后一个 epoch 的保存
|
||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。
|
||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名
|
||||
|
||||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。
|
||||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。
|
||||
|
@ -20,14 +20,15 @@ class MoreEvaluateCallback(HasMonitorCallback):
|
||||
**kwargs):
|
||||
"""
|
||||
当评测时需要调用不同的 evaluate_fn (例如在大部分生成任务中,一般使用训练 loss 作为训练过程中的 evaluate ;但同时在训练到
|
||||
一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 evaluate_fn ),只使用 Trainer
|
||||
无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 topk 以及
|
||||
topk_monitor 等相关参数。可以通过 evaluate_every 或 watch_monitor 控制触发进行 evaluate 的条件。
|
||||
一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 evaluate_fn ),只使用 Trainer
|
||||
无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 topk 以及
|
||||
topk_monitor 等相关参数。可以通过 evaluate_every 或 watch_monitor 控制触发进行 evaluate 的条件。
|
||||
|
||||
如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存
|
||||
- folder/
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名
|
||||
如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存::
|
||||
|
||||
- folder/
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名
|
||||
|
||||
:param dataloaders: 需要评估的数据
|
||||
:param metrics: 使用的 metrics 。
|
||||
|
@ -19,10 +19,11 @@ class Saver:
|
||||
def __init__(self, folder:str=None, save_object:str='model', only_state_dict:bool=True,
|
||||
model_save_fn:Callable=None, **kwargs):
|
||||
"""
|
||||
执行保存的对象。保存的文件组织结构为
|
||||
- folder # 当前初始化的参数
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- folder_name # 由 save() 调用时传入。
|
||||
执行保存的对象。保存的文件组织结构为::
|
||||
|
||||
- folder # 当前初始化的参数
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- folder_name # 由 save() 调用时传入。
|
||||
|
||||
:param folder: 保存在哪个文件夹下,默认为当前 folder 下。
|
||||
:param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。
|
||||
@ -53,10 +54,11 @@ class Saver:
|
||||
@rank_zero_call
|
||||
def save(self, trainer, folder_name):
|
||||
"""
|
||||
执行保存的函数,将数据保存在
|
||||
- folder/
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- folder_name # 当前函数参数
|
||||
执行保存的函数,将数据保存在::
|
||||
|
||||
- folder/
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- folder_name # 当前函数参数
|
||||
|
||||
:param trainer: Trainer 对象
|
||||
:param folder_name: 保存的 folder 名称,将被创建。
|
||||
@ -129,8 +131,8 @@ class TopkQueue:
|
||||
def push(self, key, value) -> Optional[Tuple[Union[str, None], Union[float, None]]]:
|
||||
"""
|
||||
将 key/value 推入 topk 的 queue 中,以 value 为标准,如果满足 topk 则保留此次推入的信息,同时如果新推入的数据将之前的数据给
|
||||
挤出了 topk ,则会返回被挤出的 (key, value);如果返回为 (None, None),说明满足 topk 且没有数据被挤出。如果不满足 topk ,则返回
|
||||
推入的 (key, value) 本身。这里排序只根据 value 是否更大了判断,因此如果有的情况是越小越好,请在输入前取负号。
|
||||
挤出了 topk ,则会返回被挤出的 (key, value);如果返回为 (None, None),说明满足 topk 且没有数据被挤出。如果不满足 topk ,则返回
|
||||
推入的 (key, value) 本身。这里排序只根据 value 是否更大了判断,因此如果有的情况是越小越好,请在输入前取负号。
|
||||
|
||||
:param str key:
|
||||
:param float value: 如果为 None, 则不做任何操作。
|
||||
@ -173,10 +175,11 @@ class TopkSaver(MonitorUtility, Saver):
|
||||
only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True,
|
||||
**kwargs):
|
||||
"""
|
||||
用来识别 topk 模型并保存,也可以仅当一个保存 Saver 使用。保存路径为
|
||||
- folder/
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名
|
||||
用来识别 topk 模型并保存,也可以仅当一个保存 Saver 使用。保存路径为::
|
||||
|
||||
- folder/
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名
|
||||
|
||||
:param topk: 保存 topk 多少的模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 topk 个。
|
||||
:param monitor: 监控哪个指标判断是否是 topk 的。监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用
|
||||
@ -208,7 +211,7 @@ class TopkSaver(MonitorUtility, Saver):
|
||||
def save_topk(self, trainer, results: Dict) -> Optional[str]:
|
||||
"""
|
||||
根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。如果返回为 None ,则说明此次没有满足
|
||||
topk 要求,没有发生保存。
|
||||
topk 要求,没有发生保存。
|
||||
|
||||
:param trainer:
|
||||
:param results: evaluate 的结果。
|
||||
|
@ -10,8 +10,7 @@ class TorchGradClipCallback(Callback):
|
||||
在每次 optimizer update 之前将 parameter 进行 clip
|
||||
|
||||
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数
|
||||
:param str clip_type: 支持'norm', 'value'
|
||||
两种::
|
||||
:param str clip_type: 支持'norm', 'value'两种::
|
||||
|
||||
1 'norm', 将gradient的norm rescale到[-clip_value, clip_value]
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, Tuple
|
||||
import os
|
||||
|
||||
from fastNLP.core.log.logger import logger
|
||||
@ -6,10 +6,10 @@ from difflib import SequenceMatcher
|
||||
from fastNLP.core.utils.utils import _get_fun_msg
|
||||
|
||||
|
||||
def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->(str, float):
|
||||
def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->Tuple[str, float]:
|
||||
"""
|
||||
从res中寻找 monitor 并返回。如果 monitor 没找到则尝试用 _real_monitor ,若 _real_monitor 为 None 则尝试使用 monitor 的值进行
|
||||
匹配。
|
||||
匹配。
|
||||
|
||||
:param monitor:
|
||||
:param real_monitor:
|
||||
|
@ -84,8 +84,8 @@ class Collator:
|
||||
def __init__(self, backend='auto'):
|
||||
"""
|
||||
用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。
|
||||
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的
|
||||
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。
|
||||
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的
|
||||
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。
|
||||
|
||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。
|
||||
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad
|
||||
@ -101,8 +101,7 @@ class Collator:
|
||||
|
||||
def __call__(self, batch)->Union[List, Dict]:
|
||||
"""
|
||||
batch可能存在三种可能性
|
||||
List[Dict], List[List], List[Sample]
|
||||
batch可能存在三种可能性:List[Dict], List[List], List[Sample]
|
||||
|
||||
第一步:使用 unpack_batch_func 将相同 field 的内容打包到一个 list 中。
|
||||
第二步:使用每个 field 各自的 padder 进行 pad 。
|
||||
@ -264,7 +263,8 @@ class Collator:
|
||||
def set_ignore(self, *field_names) -> "Collator":
|
||||
"""
|
||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||
Ex::
|
||||
Example::
|
||||
|
||||
collator.set_ignore('field1', 'field2')
|
||||
|
||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
|
@ -62,8 +62,8 @@ class Trainer(TrainerEventTrigger):
|
||||
):
|
||||
r"""
|
||||
`Trainer` 是 fastNLP 用于训练模型的专门的训练器,其支持多种不同的驱动模式,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产
|
||||
的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需
|
||||
要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP;
|
||||
的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需
|
||||
要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP;
|
||||
|
||||
:param model: 训练所需要的模型,目前支持 pytorch;
|
||||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle
|
||||
|
@ -56,12 +56,12 @@ class TrainerState:
|
||||
我们保存的state大部分上是 trainer 断点重训 需要重新加载的;
|
||||
专属于 `Trainer` 的状态记载的类;
|
||||
|
||||
n_epochs: 训练过程中总共的 epoch 的数量;
|
||||
cur_epoch_idx: 当前正在运行第几个 epoch;
|
||||
global_forward_batches: 当前模型总共 forward 了多少个 step;
|
||||
batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step;
|
||||
num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step;
|
||||
total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs;
|
||||
:param n_epochs: 训练过程中总共的 epoch 的数量;
|
||||
:param cur_epoch_idx: 当前正在运行第几个 epoch;
|
||||
:param global_forward_batches: 当前模型总共 forward 了多少个 step;
|
||||
:param batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step;
|
||||
:param num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step;
|
||||
:param total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs;
|
||||
"""
|
||||
n_epochs: Optional[int] = None # 无论如何重新算
|
||||
|
||||
|
@ -136,7 +136,7 @@ class JittorDataLoader:
|
||||
def set_ignore(self, *field_names) -> Collator:
|
||||
"""
|
||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||
Ex::
|
||||
Example::
|
||||
collator.set_ignore('field1', 'field2')
|
||||
|
||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
|
@ -141,7 +141,7 @@ class PaddleDataLoader(DataLoader):
|
||||
def set_ignore(self, *field_names) -> Collator:
|
||||
"""
|
||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||
Ex::
|
||||
Example::
|
||||
collator.set_ignore('field1', 'field2')
|
||||
|
||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
|
@ -152,7 +152,7 @@ class TorchDataLoader(DataLoader):
|
||||
def set_ignore(self, *field_names) -> Collator:
|
||||
"""
|
||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||
Ex::
|
||||
Example::
|
||||
collator.set_ignore('field1', 'field2')
|
||||
|
||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
|
@ -399,7 +399,7 @@ class DataSet:
|
||||
raise KeyError("DataSet has no field named {}.".format(field_name))
|
||||
return self
|
||||
|
||||
def apply_field(self, func: Union[Callable], field_name: str = None,
|
||||
def apply_field(self, func: Callable, field_name: str = None,
|
||||
new_field_name: str = None, num_proc: int = 0,
|
||||
progress_desc: str = None, show_progress_bar: bool = True):
|
||||
r"""
|
||||
@ -435,7 +435,7 @@ class DataSet:
|
||||
func 可以返回一个或多个 field 上的结果。
|
||||
|
||||
.. note::
|
||||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
|
||||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
|
||||
``apply`` 区别的介绍。
|
||||
|
||||
:param num_proc: 进程的数量
|
||||
|
@ -17,6 +17,8 @@ class Instance(Mapping):
|
||||
Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。
|
||||
Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示::
|
||||
|
||||
instance = Instance() # 请补充完整
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **fields):
|
||||
|
@ -49,8 +49,8 @@ class Driver(ABC):
|
||||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的
|
||||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist";
|
||||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None;
|
||||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用;
|
||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数;
|
||||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用;
|
||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数;
|
||||
|
||||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得
|
||||
可以可以加载。
|
||||
@ -66,13 +66,13 @@ class Driver(ABC):
|
||||
def set_deterministic_dataloader(self, dataloader):
|
||||
r"""
|
||||
为了确定性训练要对 dataloader 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的;例如对于 torch 的 dataloader,其
|
||||
需要将 worker_init_fn 替换;
|
||||
需要将 worker_init_fn 替换;
|
||||
"""
|
||||
|
||||
def set_sampler_epoch(self, dataloader, cur_epoch_idx):
|
||||
r"""
|
||||
对于分布式的 sampler,例如 torch 的 DistributedSampler,其需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的;
|
||||
dataloader 中可能真正发挥作用的是 batch_sampler 也可能是 sampler。
|
||||
dataloader 中可能真正发挥作用的是 batch_sampler 也可能是 sampler。
|
||||
|
||||
:param dataloader: 需要设置 epoch 的 dataloader 。
|
||||
:param cur_epoch_idx: 当前是第几个 epoch;
|
||||
@ -101,17 +101,17 @@ class Driver(ABC):
|
||||
|
||||
之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上;
|
||||
这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和
|
||||
`evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和
|
||||
`evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是
|
||||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中;
|
||||
`evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和
|
||||
`evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是
|
||||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中;
|
||||
|
||||
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示:
|
||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward`
|
||||
函数,然后给出 warning;
|
||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错;
|
||||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的
|
||||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此
|
||||
可能需要额外标记最初传入 driver 的模型是哪种形式的;
|
||||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此
|
||||
可能需要额外标记最初传入 driver 的模型是哪种形式的;
|
||||
|
||||
:param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法;
|
||||
:return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入;
|
||||
@ -202,7 +202,7 @@ class Driver(ABC):
|
||||
def get_model_no_sync_context(self):
|
||||
r"""
|
||||
返回一个用于关闭多进程之间 model 中的自动互相同步操作的 context 上下文对象;只有多卡的 driver 需要单独实现该函数,
|
||||
单卡的 driver 不需要;
|
||||
单卡的 driver 不需要;
|
||||
|
||||
:return: 返回一个类似于 DistributedDataParallel(model).no_sync 的 context 上下文对象;
|
||||
"""
|
||||
@ -273,7 +273,7 @@ class Driver(ABC):
|
||||
def load(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict:
|
||||
r"""
|
||||
断点重训的加载函数,注意该函数会负责读取数据,并且恢复 optimizers , fp16 的 state_dict 和 模型(根据 should_load_model )和;
|
||||
其它在 Driver.save() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save() 接受到的 states )。
|
||||
其它在 Driver.save() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save() 接受到的 states )。
|
||||
|
||||
该函数应该在所有 rank 上执行。
|
||||
|
||||
@ -302,7 +302,7 @@ class Driver(ABC):
|
||||
def tensor_to_numeric(tensor, reduce: Optional[str]=None):
|
||||
r"""
|
||||
将一个 `tensor` 对象(仅处理当前 driver 使用的 tensor 即可)转换为 python 的 `numeric` 对象;如果 tensor 只包含一个
|
||||
元素则返回 float 或 int 。
|
||||
元素则返回 float 或 int 。
|
||||
|
||||
:param tensor: 需要被转换的 `tensor` 对象
|
||||
:param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回
|
||||
@ -323,7 +323,7 @@ class Driver(ABC):
|
||||
"""
|
||||
保证用户拿到的模型一定是最原始的模型;
|
||||
注意因为我们把保存模型的主要逻辑和代码移到了 `Driver` 中,因此在 `save_model` 函数中,一定要先调用此函数来保证我们保存的模型一定是
|
||||
最为原始的模型;
|
||||
最为原始的模型;
|
||||
需要注意用户本身传入的模型就是经过类似 `torch.nn.DataParallel` 或者 `torch.nn.parallel.DistributedDataParallel` 包裹的模型,
|
||||
因此在该函数内需要先判断模型的类别;
|
||||
|
||||
@ -335,7 +335,7 @@ class Driver(ABC):
|
||||
r"""
|
||||
用来将模型转移到指定的 device 上;
|
||||
之所以写成 `staticmethod`,是因为一方面在 `Driver` 中我们要使用 `unwrap_model` 来拿到最原始的模型,另一方面,在 `save_model`
|
||||
中,我们需要先将模型移到 cpu 后,又再移到 gpu 上,因此不适宜在该函数内部调用 `unwrap_model`,而是将 model 作为该函数的参数;
|
||||
中,我们需要先将模型移到 cpu 后,又再移到 gpu 上,因此不适宜在该函数内部调用 `unwrap_model`,而是将 model 作为该函数的参数;
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@ -373,7 +373,7 @@ class Driver(ABC):
|
||||
def on_exception(self):
|
||||
"""
|
||||
该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 open_subprocess 的时候将每一个进程
|
||||
的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉;
|
||||
的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉;
|
||||
|
||||
因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的
|
||||
pid 的信息;
|
||||
@ -399,7 +399,7 @@ class Driver(ABC):
|
||||
def broadcast_object(self, obj, src:int=0, group=None, **kwargs):
|
||||
"""
|
||||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )broadcast 到其它所有进程。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行
|
||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。
|
||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。
|
||||
|
||||
:param obj: obj,可能是 Tensor 或 嵌套类型的数据
|
||||
:param int src: source 的 global rank 。
|
||||
@ -415,7 +415,7 @@ class Driver(ABC):
|
||||
def all_gather(self, obj, group)->List:
|
||||
"""
|
||||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过
|
||||
pickle 进行序列化,接收到之后再反序列化。
|
||||
pickle 进行序列化,接收到之后再反序列化。
|
||||
|
||||
:param obj: 可以是 float/int/bool/np.ndarray/{}/[]/Tensor等。
|
||||
:param group:
|
||||
|
@ -171,7 +171,7 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=None) ->List:
|
||||
"""
|
||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。
|
||||
|
||||
example:
|
||||
example::
|
||||
obj = {
|
||||
'a': [1, 1],
|
||||
'b': [[1, 2], [1, 2]],
|
||||
|
@ -379,13 +379,6 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
self._has_fleetwrapped = True
|
||||
|
||||
def on_exception(self):
|
||||
"""
|
||||
该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 open_subprocess 的时候将每一个进程
|
||||
的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉;
|
||||
|
||||
因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的
|
||||
pid 的信息;
|
||||
"""
|
||||
rank_zero_rm(self.gloo_rendezvous_dir)
|
||||
super().on_exception()
|
||||
|
||||
@ -420,17 +413,6 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
return self.model_device
|
||||
|
||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
|
||||
"""
|
||||
通过调用 `fn` 来实现训练时的前向传播过程;
|
||||
注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的
|
||||
函数;
|
||||
|
||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型;
|
||||
:param fn: 调用该函数进行一次计算。
|
||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call
|
||||
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward;
|
||||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查);
|
||||
"""
|
||||
if self._has_fleetwrapped:
|
||||
return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn,
|
||||
wo_auto_param_call=self.wo_auto_param_call)
|
||||
@ -441,27 +423,6 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
return fn(batch)
|
||||
|
||||
def get_model_call_fn(self, fn: str) -> Tuple:
|
||||
"""
|
||||
该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数;
|
||||
该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用;
|
||||
|
||||
之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上;
|
||||
这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和
|
||||
`evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和
|
||||
`evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是
|
||||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中;
|
||||
|
||||
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示:
|
||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward`
|
||||
函数,然后给出 warning;
|
||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错;
|
||||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的
|
||||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此
|
||||
可能需要额外标记最初传入 driver 的模型是哪种形式的;
|
||||
|
||||
:param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法;
|
||||
:return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入;
|
||||
"""
|
||||
model = self.unwrap_model()
|
||||
if self._has_fleetwrapped:
|
||||
if hasattr(model, fn):
|
||||
@ -487,24 +448,6 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]],
|
||||
reproducible: bool = False):
|
||||
r"""
|
||||
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。
|
||||
|
||||
:param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本
|
||||
:param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader
|
||||
切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在
|
||||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的
|
||||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist";
|
||||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None;
|
||||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用;
|
||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数;
|
||||
|
||||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得
|
||||
可以可以加载。
|
||||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外,
|
||||
如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的
|
||||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。
|
||||
"""
|
||||
# 暂时不支持iterableDataset
|
||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \
|
||||
"FastNLP does not support `IteratorDataset` now."
|
||||
@ -619,43 +562,9 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
f"not {type(each_optimizer)}.")
|
||||
|
||||
def broadcast_object(self, obj, src:int=0, group=None, **kwargs):
|
||||
"""
|
||||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行
|
||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。
|
||||
|
||||
:param obj: obj,可能是 Tensor 或 嵌套类型的数据
|
||||
:param int src: source 的 global rank 。
|
||||
:param int dst: target 的 global rank,可以是多个目标 rank
|
||||
:param group: 所属的 group
|
||||
:param kwargs:
|
||||
:return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回
|
||||
接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。
|
||||
"""
|
||||
# 因为设置了CUDA_VISIBLE_DEVICES,可能会引起错误
|
||||
device = get_device_from_visible(self.data_device)
|
||||
return fastnlp_paddle_broadcast_object(obj, src, device=device, group=group)
|
||||
|
||||
def all_gather(self, obj, group=None) -> List:
|
||||
"""
|
||||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过
|
||||
pickle 进行序列化,接收到之后再反序列化。
|
||||
|
||||
example:
|
||||
obj = {
|
||||
'a': [1, 1],
|
||||
'b': [[1, 2], [1, 2]],
|
||||
'c': {
|
||||
'd': [1, 2]
|
||||
}
|
||||
}
|
||||
->
|
||||
[
|
||||
{'a': 1, 'b':[1, 2], 'c':{'d': 1}},
|
||||
{'a': 1, 'b':[1, 2], 'c':{'d': 2}}
|
||||
]
|
||||
|
||||
:param obj: 需要传输的对象,在每个rank上都应该保持相同的结构。
|
||||
:param group:
|
||||
:return:
|
||||
"""
|
||||
return fastnlp_paddle_all_gather(obj, group=group)
|
||||
|
@ -47,7 +47,7 @@ if _NEED_IMPORT_PADDLE:
|
||||
|
||||
class PaddleDriver(Driver):
|
||||
r"""
|
||||
Paddle框架的Driver,包括实现单卡训练的`PaddleSingleDriver`和分布式训练的`PaddleFleetDriver`。
|
||||
Paddle框架的Driver,包括实现单卡训练的 `PaddleSingleDriver` 和分布式训练的 `PaddleFleetDriver`。
|
||||
"""
|
||||
def __init__(self, model, fp16: Optional[bool] = False, **kwargs):
|
||||
if not isinstance(model, paddle.nn.Layer):
|
||||
@ -131,8 +131,7 @@ class PaddleDriver(Driver):
|
||||
@staticmethod
|
||||
def tensor_to_numeric(tensor, reduce=None):
|
||||
r"""
|
||||
将一个 `tensor` 对象(类型为 `paddle.Tensor` )转换为 python 的 `numeric` 对象;如果 tensor 只包含一个
|
||||
元素则返回 float 或 int 。
|
||||
将一个 `tensor` 对象(类型为 `paddle.Tensor` )转换为 python 的 `numeric` 对象;如果 tensor 只包含一个元素则返回 float 或 int 。
|
||||
|
||||
:param tensor: 需要被转换的 `tensor` 对象
|
||||
:param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回
|
||||
@ -158,11 +157,6 @@ class PaddleDriver(Driver):
|
||||
)
|
||||
|
||||
def set_model_mode(self, mode: str):
|
||||
r"""
|
||||
设置模型为 `train` / `eval` 的模式;目的是为切换模型训练和推理(会关闭dropout等)模式;
|
||||
|
||||
:param mode: 应为二者之一:["train", "eval"];
|
||||
"""
|
||||
assert mode in {"train", "eval"}
|
||||
getattr(self.model, mode)()
|
||||
|
||||
@ -179,7 +173,6 @@ class PaddleDriver(Driver):
|
||||
可以通过 InputSpec 或者示例 Tensor 进行描述。详细的可以参考 paddle 关于`paddle.jit.save`
|
||||
的文档:
|
||||
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save
|
||||
:return:
|
||||
"""
|
||||
model = self.unwrap_model()
|
||||
if isinstance(filepath, Path):
|
||||
@ -196,12 +189,12 @@ class PaddleDriver(Driver):
|
||||
|
||||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
|
||||
r"""
|
||||
加载模型的函数;注意函数 `load` 是用来进行断点重训的函数;
|
||||
加载模型的函数;将 filepath 中的模型加载并赋值给当前 model 。
|
||||
|
||||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名);
|
||||
:param only_state_dict: 是否加载state_dict,默认为True。
|
||||
:param kwargs:
|
||||
:return:
|
||||
:param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath
|
||||
模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。
|
||||
:return: 返回加载指定文件后的结果;
|
||||
"""
|
||||
model = self.unwrap_model()
|
||||
if isinstance(filepath, Path):
|
||||
@ -216,22 +209,6 @@ class PaddleDriver(Driver):
|
||||
|
||||
@rank_zero_call
|
||||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
|
||||
r"""
|
||||
断点重训的保存函数,该函数会负责保存模型和 optimizers 的 state_dict;
|
||||
需要注意 driver 应当是无状态的,即不管什么时候调用 driver 的接口函数,其返回的结果应该都是一样的;因此,断点重训不需要保存 driver
|
||||
本身自己的任何状态;而每一个 driver 实例需要在该函数中实现保存模型和 optimizers 的 state_dict 的逻辑;同时妥善存储传入的
|
||||
states 中的内容(主要用于恢复 Trainer ,Callback 等)
|
||||
需要保证该函数只在 global rank 0 上运行
|
||||
|
||||
:param folder: 保存断点重训的状态的文件名;
|
||||
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存
|
||||
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load()返回的值与这里的
|
||||
传入的值保持一致。
|
||||
:param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。
|
||||
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。
|
||||
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。
|
||||
:return:
|
||||
"""
|
||||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变
|
||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境;
|
||||
|
||||
@ -422,19 +399,10 @@ class PaddleDriver(Driver):
|
||||
random.seed(stdlib_seed)
|
||||
|
||||
def set_deterministic_dataloader(self, dataloader):
|
||||
r"""
|
||||
为了确定性训练要对 dataloader 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的;
|
||||
作用是替换 datalaoder 的 `worker_init_fn`。
|
||||
"""
|
||||
if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None:
|
||||
dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank)
|
||||
|
||||
def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx):
|
||||
r"""
|
||||
对于分布式的 sampler,dataloader 需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的;
|
||||
|
||||
:param cur_epoch_idx: 当前是第几个 epoch;
|
||||
"""
|
||||
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
|
||||
dataloader.batch_sampler.set_epoch(cur_epoch_idx)
|
||||
|
||||
|
@ -73,44 +73,12 @@ class PaddleSingleDriver(PaddleDriver):
|
||||
self.model.to(device)
|
||||
|
||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
|
||||
"""
|
||||
通过调用 `fn` 来实现训练时的前向传播过程;
|
||||
注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的
|
||||
函数;
|
||||
|
||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型;
|
||||
:param fn: 调用该函数进行一次计算。
|
||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call
|
||||
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward;
|
||||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查);
|
||||
"""
|
||||
if isinstance(batch, Dict) and not self.wo_auto_param_call:
|
||||
return auto_param_call(fn, batch, signature_fn=signature_fn)
|
||||
else:
|
||||
return fn(batch)
|
||||
|
||||
def get_model_call_fn(self, fn: str) -> Tuple:
|
||||
"""
|
||||
该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数;
|
||||
该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用;
|
||||
|
||||
之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上;
|
||||
这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和
|
||||
`evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和
|
||||
`evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是
|
||||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中;
|
||||
|
||||
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示:
|
||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward`
|
||||
函数,然后给出 warning;
|
||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错;
|
||||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的
|
||||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此
|
||||
可能需要额外标记最初传入 driver 的模型是哪种形式的;
|
||||
|
||||
:param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法;
|
||||
:return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入;
|
||||
"""
|
||||
if hasattr(self.model, fn):
|
||||
fn = getattr(self.model, fn)
|
||||
if not callable(fn):
|
||||
@ -125,24 +93,6 @@ class PaddleSingleDriver(PaddleDriver):
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
|
||||
reproducible: bool = False):
|
||||
r"""
|
||||
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。
|
||||
|
||||
:param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本
|
||||
:param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader
|
||||
切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在
|
||||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的
|
||||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist";
|
||||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None;
|
||||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用;
|
||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数;
|
||||
|
||||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得
|
||||
可以可以加载。
|
||||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外,
|
||||
如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的
|
||||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。
|
||||
"""
|
||||
|
||||
# 暂时不支持iterableDataset
|
||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \
|
||||
@ -187,7 +137,7 @@ class PaddleSingleDriver(PaddleDriver):
|
||||
@property
|
||||
def data_device(self):
|
||||
"""
|
||||
单卡模式不支持 data_device;
|
||||
返回数据所在的设备。由于单卡模式不支持 data_device,因此返回的是 model_device
|
||||
"""
|
||||
return self.model_device
|
||||
|
||||
|
@ -9,9 +9,8 @@ def print(*args, sep=' ', end='\n', file=None, flush=False):
|
||||
"""
|
||||
用来重定向 print 函数至 logger.info 的函数。
|
||||
|
||||
Example:
|
||||
Example::
|
||||
from fastNLP import print
|
||||
|
||||
print("This is a test") # 等价于调用了 logger.info("This is a test")
|
||||
|
||||
:param args: 需要打印的内容
|
||||
|
@ -8,7 +8,7 @@ from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, Unrepeat
|
||||
def conversion_between_reproducible_and_unrepeated_sampler(sampler):
|
||||
"""
|
||||
将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的
|
||||
ReproducibleSampler,
|
||||
ReproducibleSampler,
|
||||
|
||||
:param sampler:
|
||||
:return:
|
||||
|
@ -299,7 +299,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
def total_size(self):
|
||||
"""
|
||||
这个变量代表的含义是当前这个sampler会最终产生出的index数量(包括了其它rank的),因为replica和pad的原因,这个值可能等于、
|
||||
大于或者小于len(dataset)
|
||||
大于或者小于len(dataset)
|
||||
|
||||
:return:
|
||||
"""
|
||||
@ -367,7 +367,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
|
||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
|
||||
"""
|
||||
首先按照 sample 的长度排序,然后按照 batch_size*num_batch_per_bucket 为一个桶的大小,sample 只会在这个桶内进行组合,这样
|
||||
每个 batch 中的 padding 数量会比较少 (因为桶内的数据的长度都接近)。
|
||||
每个 batch 中的 padding 数量会比较少 (因为桶内的数据的长度都接近)。
|
||||
|
||||
:param dataset: 实现了 __len__ 方法的数据容器。
|
||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的
|
||||
@ -440,7 +440,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
|
||||
def total_size(self):
|
||||
"""
|
||||
这个变量代表的含义是当前这个sampler会最终产生出的index数量(包括了其它rank的),因为replica和pad的原因,这个值可能等于、
|
||||
大于或者小于len(dataset)
|
||||
大于或者小于len(dataset)
|
||||
|
||||
:return:
|
||||
"""
|
||||
|
@ -19,7 +19,7 @@ class ReproducibleSampler:
|
||||
可复现的 Sampler 对象。
|
||||
|
||||
注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler
|
||||
或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。
|
||||
或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。
|
||||
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
@ -87,7 +87,7 @@ class RandomSampler(ReproducibleSampler):
|
||||
def __iter__(self):
|
||||
r"""
|
||||
当前使用num_consumed_samples做法会在交替使用的时候遇到问题;
|
||||
Example:
|
||||
Example::
|
||||
>>> sampler = RandomSampler()
|
||||
>>> iter1 = iter(sampler)
|
||||
>>> iter2 = iter(sampler)
|
||||
|
@ -99,7 +99,7 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler):
|
||||
def __init__(self, dataset, length:Union[str, List], **kwargs):
|
||||
"""
|
||||
将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的
|
||||
batch 数量不完全一致。
|
||||
batch 数量不完全一致。
|
||||
|
||||
:param dataset: 实现了 __len__ 方法的数据容器。
|
||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的
|
||||
|
@ -35,7 +35,7 @@ class NumConsumedSamplesArray:
|
||||
def __init__(self, buffer_size=2000, num_consumed_samples=0):
|
||||
"""
|
||||
保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少
|
||||
ex:
|
||||
Example::
|
||||
array = NumConsumedSamplesArray(buffer_size=3)
|
||||
for i in range(10):
|
||||
array.push(i)
|
||||
|
@ -17,7 +17,8 @@ from .utils import apply_to_collection
|
||||
class TorchTransferableDataType(ABC):
|
||||
"""
|
||||
A custom type for data that can be moved to a torch device via `.to(...)`.
|
||||
Example:
|
||||
Example::
|
||||
|
||||
>>> isinstance(dict, TorchTransferableDataType)
|
||||
False
|
||||
>>> isinstance(torch.rand(2, 3), TorchTransferableDataType)
|
||||
|
@ -52,11 +52,11 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
|
||||
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
|
||||
r"""
|
||||
该函数会根据输入函数的形参名从*args(因此都需要是dict类型)中找到匹配的值进行调用,如果传入的数据与fn的形参不匹配,可以通过mapping
|
||||
参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。
|
||||
参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。
|
||||
|
||||
1.该函数用来提供给用户根据字符串匹配从而实现自动调用;
|
||||
2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来;
|
||||
如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性;
|
||||
如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性;
|
||||
3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值;
|
||||
4.如果输入的函数是一个 `partial` 函数,情况同 '3.',即和默认参数的情况相同;
|
||||
|
||||
@ -68,7 +68,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
|
||||
|
||||
:return: 返回 `fn` 运行的结果;
|
||||
|
||||
Examples:
|
||||
Examples::
|
||||
>>> # 1
|
||||
>>> loss_fn = CrossEntropyLoss() # 如果其需要的参数为 def CrossEntropyLoss(y, pred);
|
||||
>>> batch = {"x": 20, "y": 1}
|
||||
@ -190,7 +190,7 @@ def _get_fun_msg(fn, with_fp=True)->str:
|
||||
def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
|
||||
"""
|
||||
检查一个函数是否需要 expected_params 参数(检测数量是否匹配)。除掉 self (如果是method),给定默认值的参数等。如果匹配不上,就会
|
||||
进行报错。
|
||||
进行报错。
|
||||
|
||||
:param fn: 需要检测的函数,可以是 method 或者 function 。
|
||||
:param expected_params: 期待应该支持的参数。
|
||||
|
@ -20,7 +20,7 @@ def is_cur_env_distributed() -> bool:
|
||||
"""
|
||||
单卡模式该函数一定返回 False;
|
||||
注意进程 0 在多卡的训练模式下前后的值是不一样的,例如在开启多卡的 driver 之前,在进程 0 上的该函数返回 False;但是在开启后,在进程 0 上
|
||||
的该函数返回的值是 True;
|
||||
的该函数返回的值是 True;
|
||||
多卡模式下除了进程 0 外的其它进程返回的值一定是 True;
|
||||
"""
|
||||
return FASTNLP_GLOBAL_RANK in os.environ
|
||||
@ -34,12 +34,14 @@ def rank_zero_call(fn: Callable):
|
||||
"""
|
||||
通过该函数包裹的函数,在单卡模式下该方法不影响任何东西,在多卡状态下仅会在 global rank 为 0 的进程下执行。使用方式有两种
|
||||
|
||||
# 使用方式1
|
||||
使用方式1::
|
||||
|
||||
@rank_zero_call
|
||||
def save_model():
|
||||
do_something # will only run in global rank 0
|
||||
|
||||
# 使用方式2
|
||||
使用方式2::
|
||||
|
||||
def add(a, b):
|
||||
return a+b
|
||||
rank_zero_call(add)(1, 2)
|
||||
@ -103,7 +105,7 @@ def all_rank_call_context():
|
||||
def rank_zero_rm(path: Optional[Union[str, Path]]):
|
||||
"""
|
||||
这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候
|
||||
在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件;
|
||||
在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件;
|
||||
该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。
|
||||
|
||||
:param path:
|
||||
|
@ -223,7 +223,7 @@ class DataBundle:
|
||||
def apply_field(self, func: Callable, field_name: str, new_field_name: str, num_proc: int = 0,
|
||||
ignore_miss_dataset: bool = True, progress_desc: str = '', show_progress_bar: bool = True):
|
||||
r"""
|
||||
对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :method:`~fastNLP.DataSet.apply_field` 方法
|
||||
对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法
|
||||
|
||||
:param callable func: input是instance中名为 `field_name` 的field的内容。
|
||||
:param str field_name: 传入func的是哪个field。
|
||||
@ -233,8 +233,8 @@ class DataBundle:
|
||||
如果为False,则报错
|
||||
:param ignore_miss_dataset:
|
||||
:param num_proc:
|
||||
:param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称
|
||||
:param show_progress_bar 是否显示tqdm进度条
|
||||
:param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称
|
||||
:param show_progress_bar: 是否显示tqdm进度条
|
||||
|
||||
"""
|
||||
_progress_desc = progress_desc
|
||||
@ -251,10 +251,10 @@ class DataBundle:
|
||||
def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
|
||||
ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True):
|
||||
r"""
|
||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply_field_more` 方法
|
||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法
|
||||
|
||||
.. note::
|
||||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
|
||||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
|
||||
``apply`` 区别的介绍。
|
||||
|
||||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
|
||||
@ -285,7 +285,7 @@ class DataBundle:
|
||||
def apply(self, func: Callable, new_field_name: str, num_proc: int = 0,
|
||||
progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None):
|
||||
r"""
|
||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法
|
||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法
|
||||
|
||||
对DataBundle中所有的dataset使用apply方法
|
||||
|
||||
@ -295,7 +295,7 @@ class DataBundle:
|
||||
:param _apply_field:
|
||||
:param show_progress_bar: 是否显示tqd进度条
|
||||
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称
|
||||
:param num_proc
|
||||
:param num_proc:
|
||||
|
||||
"""
|
||||
_progress_desc = progress_desc
|
||||
@ -309,17 +309,17 @@ class DataBundle:
|
||||
def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0,
|
||||
progress_desc: str = '', show_progress_bar: bool = True):
|
||||
r"""
|
||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply_more` 方法
|
||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法
|
||||
|
||||
.. note::
|
||||
``apply_more`` 与 ``apply`` 的区别参考 :method:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
|
||||
``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
|
||||
``apply`` 区别的介绍。
|
||||
|
||||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
|
||||
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
|
||||
:param show_progress_bar: 是否显示tqd进度条
|
||||
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称
|
||||
:param num_proc
|
||||
:param num_proc:
|
||||
|
||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
|
||||
"""
|
||||
@ -359,7 +359,8 @@ class DataBundle:
|
||||
def set_ignore(self, *field_names) -> "DataBundle":
|
||||
"""
|
||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
|
||||
Ex::
|
||||
Example::
|
||||
|
||||
collator.set_ignore('field1', 'field2')
|
||||
|
||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
|
||||
|
@ -1,4 +1,7 @@
|
||||
r"""undocumented"""
|
||||
r"""
|
||||
.. todo::
|
||||
doc
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"ExtCNNDMLoader"
|
||||
@ -19,9 +22,9 @@ class ExtCNNDMLoader(JsonLoader):
|
||||
.. csv-table::
|
||||
:header: "text", "summary", "label", "publication"
|
||||
|
||||
["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm"
|
||||
["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm"
|
||||
["..."], ["..."], [], "cnndm"
|
||||
"['I got new tires from them and... ','...']", "['The new tires...','...']", "[0, 1]", "cnndm"
|
||||
"['Don't waste your time. We had two...','...']", "['Time is precious','...']", "[1]", "cnndm"
|
||||
"["..."]", "["..."]", "[]", "cnndm"
|
||||
|
||||
"""
|
||||
|
||||
|
@ -87,7 +87,7 @@ class CLSBasePipe(Pipe):
|
||||
|
||||
def process_from_file(self, paths) -> DataBundle:
|
||||
r"""
|
||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()`
|
||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()`
|
||||
|
||||
:param paths:
|
||||
:return: DataBundle
|
||||
|
@ -164,7 +164,7 @@ class GraphBuilderBase:
|
||||
|
||||
def build_graph_from_file(self, path: str):
|
||||
r"""
|
||||
传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()`
|
||||
传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()`
|
||||
|
||||
:param path:
|
||||
:return: scipy_sparse_matrix
|
||||
|
@ -33,7 +33,7 @@ class Pipe:
|
||||
|
||||
def process_from_file(self, paths: str) -> DataBundle:
|
||||
r"""
|
||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()`
|
||||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()`
|
||||
|
||||
:param str paths:
|
||||
:return: DataBundle
|
||||
|
@ -53,7 +53,7 @@ class ExtCNNDMPipe(Pipe):
|
||||
|
||||
:param data_bundle:
|
||||
:return: 处理得到的数据包括
|
||||
.. csv-table::
|
||||
.. csv-table::
|
||||
:header: "text_wd", "words", "seq_len", "target"
|
||||
|
||||
[["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0]
|
||||
|
@ -40,6 +40,7 @@ class MixModule:
|
||||
def named_parameters(self, prefix='', recurse: bool=True, backend=None):
|
||||
"""
|
||||
返回模型的名字和参数
|
||||
|
||||
:param prefix: 输出时在参数名前加上的前缀
|
||||
:param recurse: 是否递归地输出参数
|
||||
:param backend: `backend`=`None`时,将所有模型和张量的参数返回;
|
||||
@ -68,6 +69,7 @@ class MixModule:
|
||||
def parameters(self, recurse: bool = True, backend: str = None):
|
||||
"""
|
||||
返回模型的参数
|
||||
|
||||
:param recurse:
|
||||
:param backend: `backend`=`None`时,将所有模型和张量的参数返回;
|
||||
`backend`=`torch`时,返回`torch`的参数;
|
||||
@ -129,7 +131,9 @@ class MixModule:
|
||||
def state_dict(self, backend: str = None) -> Dict:
|
||||
"""
|
||||
返回模型的state_dict。
|
||||
NOTE: torch的destination参数会在将来删除,因此不提供destination参数
|
||||
|
||||
.. note:: torch的destination参数会在将来删除,因此不提供destination参数
|
||||
|
||||
:param backend: `backend`=`None`时,将所有模型和张量的state dict返回;
|
||||
`backend`=`torch`时,返回`torch`的state dict;
|
||||
`backend`=`paddle`时,返回`paddle`的state dict。
|
||||
|
@ -156,6 +156,7 @@ def _torch2jittor(torch_tensor: 'torch.Tensor', no_gradient: bool = None) -> 'ji
|
||||
def torch2paddle(torch_in: Any, target_device: str = None, no_gradient: bool = None) -> Any:
|
||||
"""
|
||||
递归地将输入中包含的torch张量转换为paddle张量
|
||||
|
||||
:param torch_in: 要转换的包含torch.Tensor类型的变量
|
||||
:param target_device: 是否将转换后的张量迁移到特定设备上,
|
||||
输入为`None`时,和输入的张量相同,
|
||||
@ -176,6 +177,7 @@ def torch2paddle(torch_in: Any, target_device: str = None, no_gradient: bool = N
|
||||
def paddle2torch(paddle_in: Any, target_device: str = None, no_gradient: bool = None) -> Any:
|
||||
"""
|
||||
递归地将输入中包含的paddle张量转换为torch张量
|
||||
|
||||
:param torch_in: 要转换的包含paddle.Tensor类型的变量
|
||||
:param target_device: 是否将转换后的张量迁移到特定设备上,
|
||||
输入为`None`时,和输入的张量相同,
|
||||
@ -196,6 +198,7 @@ def paddle2torch(paddle_in: Any, target_device: str = None, no_gradient: bool =
|
||||
def jittor2torch(jittor_in: Any, target_device: str = None, no_gradient: bool = None) -> Any:
|
||||
"""
|
||||
递归地将输入中包含的jittor变量转换为torch张量
|
||||
|
||||
:param jittor_in: 要转换的jittor变量
|
||||
:param target_device: 是否将转换后的张量迁移到特定设备上,输入为`None`时,默认为cuda:0。
|
||||
:param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致;
|
||||
@ -215,6 +218,7 @@ def jittor2torch(jittor_in: Any, target_device: str = None, no_gradient: bool =
|
||||
def torch2jittor(torch_in: Any, no_gradient: bool = None) -> Any:
|
||||
"""
|
||||
递归地将输入中包含的torch张量转换为jittor变量
|
||||
|
||||
:param torch_tensor: 要转换的torch张量
|
||||
:param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致;
|
||||
为`True`时,全部不保留梯度;为`False`时,全部保留梯度。
|
||||
|
Loading…
Reference in New Issue
Block a user