修改了 trainer.load_model/load 只有单卡加载的bug

This commit is contained in:
YWMditto 2022-04-23 14:44:18 +08:00
parent 1ecbdc7446
commit 711bbf469c

View File

@ -576,7 +576,7 @@ class Trainer(TrainerEventTrigger):
if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
rank_zero_call(model_load_fn)(folder)
model_load_fn(folder)
else:
if isinstance(folder, str):
folder = Path(folder)
@ -653,7 +653,7 @@ class Trainer(TrainerEventTrigger):
if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable`.")
rank_zero_call(model_load_fn)(folder)
model_load_fn(folder)
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
else:
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)