新增train_input_mapping 和 evaluate_input_mapping 等

This commit is contained in:
yh_cc 2022-04-25 16:47:29 +08:00
parent 6ddcedaaeb
commit 7e40d98404

View File

@ -103,10 +103,12 @@ class Trainer(TrainerEventTrigger):
value如果 batch 是一个 `dataclass`那么我们会先将该 dataclass 转换为一个 Dict然后再进行上述转换如果 batch 此时是其它
类型那么我们将会直接报错如果 input_mapping 是一个函数那么对于取出的 batch我们将不会做任何处理而是直接将其传入该函数里
注意该参数会被传进 `Evaluator` 因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作例如当参数 `device` None
如果 train evaluate 需要使用不同的 input_mapping, 请使用 train_input_mapping evaluate_input_mapping 设置
:param output_mapping: 应当为一个字典或者函数作用和 input_mapping 类似区别在于其用于转换输出如果 output_mapping 是一个
函数那么我们将会直接将模型的输出传给该函数如果其是一个 `Dict`那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型
如果 batch 是一个 `Dict`那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key value
如果 batch 是一个 `dataclass`那么我们会先将该 dataclass 转换为一个 Dict然后再进行上述转换
如果 train evaluate 需要使用不同的 output_mapping, 请使用 train_output_mapping evaluate_output_mapping 设置
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch forward 函数的参数的行为
如果该值为 False并且当 batch 为字典时我们会根据 forward 所需要的参数从 batch 中提取对应的对象传入到 forward 函数中如果该值
True那么我们会将 batch 直接透传给模型注意该参数应用于 `train_step`, `evaluate_step` `test_step`
@ -133,6 +135,10 @@ class Trainer(TrainerEventTrigger):
progress_bar: 以哪种方式显示 progress 目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象
默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback否则使用 RawTextCallback对象如果
需要定制 progress bar 的参数例如打印频率等可以传入 RichCallback, RawTextCallback 对象
train_input_mapping: input_mapping 一致但是只用于 train input_mapping 互斥
train_output_mapping: output_mapping 一致但是只用于 train input_mapping 互斥
evaluate_input_mapping: input_mapping 一致但是只用于 evaluate input_mapping 互斥
evaluate_output_mapping: output_mapping 一致但是只用于 evaluate input_mapping 互斥
"""
self.model = model
self.marker = marker
@ -147,8 +153,18 @@ class Trainer(TrainerEventTrigger):
self.evaluate_dataloaders = evaluate_dataloaders
self.optimizers = optimizers
self.fp16 = fp16
self.input_mapping = input_mapping
self.output_mapping = output_mapping
train_input_mapping = kwargs.get('train_input_mapping', None)
train_output_mapping = kwargs.get('train_output_mapping', None)
evaluate_input_mapping = kwargs.get('evaluate_input_mapping', None)
evaluate_output_mapping = kwargs.get('evaluate_output_mapping', None)
train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping = \
_get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping,
evaluate_input_mapping, evaluate_output_mapping)
self.input_mapping = train_input_mapping
self.output_mapping = train_output_mapping
self.evaluate_fn = evaluate_fn
self.batch_step_fn = batch_step_fn
@ -185,8 +201,8 @@ class Trainer(TrainerEventTrigger):
callbacks=callbacks,
metrics=metrics,
evaluate_every=evaluate_every,
input_mapping=input_mapping,
output_mapping=output_mapping,
input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping,
model_wo_auto_param_call=model_wo_auto_param_call,
accumulation_steps=accumulation_steps,
fp16=fp16,
@ -854,6 +870,32 @@ class Trainer(TrainerEventTrigger):
self._evaluate_dataloaders = evaluate_dataloaders
def _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping,
evaluate_input_mapping, evaluate_output_mapping):
if train_input_mapping is not None and input_mapping is not None:
raise ValueError("Parameter `input_mapping` and `train_input_mapping` cannot be set simultaneously.")
if evaluate_input_mapping is not None and input_mapping is not None:
raise ValueError("Parameter `input_mapping` and `evaluate_input_mapping` cannot be set simultaneously.")
if train_output_mapping is not None and output_mapping is not None:
raise ValueError("Parameter `output_mapping` and `train_output_mapping` cannot be set simultaneously.")
if evaluate_output_mapping is not None and output_mapping is not None:
raise ValueError("Parameter `output_mapping` and `evaluate_output_mapping` cannot be set simultaneously.")
if train_input_mapping is None:
train_input_mapping = input_mapping
if evaluate_input_mapping is None:
evaluate_input_mapping = input_mapping
if train_output_mapping is None:
train_output_mapping = output_mapping
if evaluate_output_mapping is None:
evaluate_output_mapping = output_mapping
return train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping