mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 12:17:35 +08:00
checkpoint callback 加入了 on_after_trainer_initialized 的逻辑
This commit is contained in:
parent
afb87b4375
commit
57caf1d028
@ -48,8 +48,9 @@ class CheckpointCallback(Callback):
|
||||
model_save_fn: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if monitor is None and save_topk is not None:
|
||||
raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.")
|
||||
# 我们新加了逻辑,如果 checkpoint callback 自己没有设置 monitor 和 larger_better,那么我们会将其在 trainer 中的设置赋值给它们;
|
||||
# if monitor is None and save_topk is not None:
|
||||
# raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.")
|
||||
|
||||
if monitor is not None and not isinstance(monitor, str):
|
||||
raise ValueError("Parameter `monitor` should be of 'str' type.")
|
||||
@ -119,6 +120,19 @@ class CheckpointCallback(Callback):
|
||||
# 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行;
|
||||
synchronize_mkdir(self.timestamp_path)
|
||||
|
||||
def on_after_trainer_initialized(self, trainer, driver):
|
||||
if self.monitor is None:
|
||||
if trainer.monitor is not None:
|
||||
self.monitor = trainer.monitor
|
||||
self.larger_better = trainer.larger_better
|
||||
elif self.save_topk is not None:
|
||||
raise RuntimeError("You are using `topk` mode, but you have not set the `monitor` value either in this"
|
||||
"callback or in trainer.")
|
||||
else:
|
||||
self.monitor = None
|
||||
if self.save_topk is not None and trainer.evaluator is None:
|
||||
raise RuntimeError("You are using `topk` mode, but there is no `evaluator` in trainer.")
|
||||
|
||||
def on_validate_end(self, trainer, validate_res):
|
||||
self._save_topk(trainer, validate_res)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user