Trainer中保存最佳模型存在bug

This commit is contained in:
yh_cc 2019-08-28 23:06:13 +08:00
parent 6201f66178
commit a46b8f129b

View File

@ -718,7 +718,7 @@ class Trainer(object):
self._save_model(self.model,
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
elif self._load_best_model:
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.state_dict()}
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.state_dict().items()}
self.best_dev_perf = res
self.best_dev_epoch = epoch
self.best_dev_step = step