mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
修改了 trainer.load_model/load 只有单卡加载的bug
This commit is contained in:
parent
1ecbdc7446
commit
711bbf469c
@ -576,7 +576,7 @@ class Trainer(TrainerEventTrigger):
|
||||
if model_load_fn is not None:
|
||||
if not callable(model_load_fn):
|
||||
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
|
||||
rank_zero_call(model_load_fn)(folder)
|
||||
model_load_fn(folder)
|
||||
else:
|
||||
if isinstance(folder, str):
|
||||
folder = Path(folder)
|
||||
@ -653,7 +653,7 @@ class Trainer(TrainerEventTrigger):
|
||||
if model_load_fn is not None:
|
||||
if not callable(model_load_fn):
|
||||
raise ValueError("Parameter `model_save_fn` should be `Callable`.")
|
||||
rank_zero_call(model_load_fn)(folder)
|
||||
model_load_fn(folder)
|
||||
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
|
||||
else:
|
||||
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user