mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
新增train_input_mapping 和 evaluate_input_mapping 等
This commit is contained in:
parent
6ddcedaaeb
commit
7e40d98404
@ -103,10 +103,12 @@ class Trainer(TrainerEventTrigger):
|
||||
value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它
|
||||
类型,那么我们将会直接报错;如果 input_mapping 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里;
|
||||
注意该参数会被传进 `Evaluator` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 `device` 为 None 时);
|
||||
如果 train 和 evaluate 需要使用不同的 input_mapping, 请使用 train_input_mapping 与 evaluate_input_mapping 设置。
|
||||
:param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个
|
||||
函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型,
|
||||
如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value;
|
||||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;
|
||||
如果 train 和 evaluate 需要使用不同的 output_mapping, 请使用 train_output_mapping 与 evaluate_output_mapping 设置。
|
||||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为;
|
||||
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值
|
||||
为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `evaluate_step` 和 `test_step`;
|
||||
@ -133,6 +135,10 @@ class Trainer(TrainerEventTrigger):
|
||||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
|
||||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果
|
||||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
|
||||
train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。
|
||||
train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。
|
||||
evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。
|
||||
evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。
|
||||
"""
|
||||
self.model = model
|
||||
self.marker = marker
|
||||
@ -147,8 +153,18 @@ class Trainer(TrainerEventTrigger):
|
||||
self.evaluate_dataloaders = evaluate_dataloaders
|
||||
self.optimizers = optimizers
|
||||
self.fp16 = fp16
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
train_input_mapping = kwargs.get('train_input_mapping', None)
|
||||
train_output_mapping = kwargs.get('train_output_mapping', None)
|
||||
evaluate_input_mapping = kwargs.get('evaluate_input_mapping', None)
|
||||
evaluate_output_mapping = kwargs.get('evaluate_output_mapping', None)
|
||||
|
||||
train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping = \
|
||||
_get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping,
|
||||
evaluate_input_mapping, evaluate_output_mapping)
|
||||
|
||||
self.input_mapping = train_input_mapping
|
||||
self.output_mapping = train_output_mapping
|
||||
self.evaluate_fn = evaluate_fn
|
||||
|
||||
self.batch_step_fn = batch_step_fn
|
||||
@ -185,8 +201,8 @@ class Trainer(TrainerEventTrigger):
|
||||
callbacks=callbacks,
|
||||
metrics=metrics,
|
||||
evaluate_every=evaluate_every,
|
||||
input_mapping=input_mapping,
|
||||
output_mapping=output_mapping,
|
||||
input_mapping=evaluate_input_mapping,
|
||||
output_mapping=evaluate_output_mapping,
|
||||
model_wo_auto_param_call=model_wo_auto_param_call,
|
||||
accumulation_steps=accumulation_steps,
|
||||
fp16=fp16,
|
||||
@ -854,6 +870,32 @@ class Trainer(TrainerEventTrigger):
|
||||
self._evaluate_dataloaders = evaluate_dataloaders
|
||||
|
||||
|
||||
def _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping,
|
||||
evaluate_input_mapping, evaluate_output_mapping):
|
||||
if train_input_mapping is not None and input_mapping is not None:
|
||||
raise ValueError("Parameter `input_mapping` and `train_input_mapping` cannot be set simultaneously.")
|
||||
|
||||
if evaluate_input_mapping is not None and input_mapping is not None:
|
||||
raise ValueError("Parameter `input_mapping` and `evaluate_input_mapping` cannot be set simultaneously.")
|
||||
|
||||
if train_output_mapping is not None and output_mapping is not None:
|
||||
raise ValueError("Parameter `output_mapping` and `train_output_mapping` cannot be set simultaneously.")
|
||||
|
||||
if evaluate_output_mapping is not None and output_mapping is not None:
|
||||
raise ValueError("Parameter `output_mapping` and `evaluate_output_mapping` cannot be set simultaneously.")
|
||||
|
||||
if train_input_mapping is None:
|
||||
train_input_mapping = input_mapping
|
||||
if evaluate_input_mapping is None:
|
||||
evaluate_input_mapping = input_mapping
|
||||
|
||||
if train_output_mapping is None:
|
||||
train_output_mapping = output_mapping
|
||||
if evaluate_output_mapping is None:
|
||||
evaluate_output_mapping = output_mapping
|
||||
|
||||
return train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user