From 711bbf469c8b5c8357990767fbbd9b0f59e6724c Mon Sep 17 00:00:00 2001 From: YWMditto Date: Sat, 23 Apr 2022 14:44:18 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=20trainer.load=5Fmo?= =?UTF-8?q?del/load=20=E5=8F=AA=E6=9C=89=E5=8D=95=E5=8D=A1=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index e4cd2817..afd5d06a 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -576,7 +576,7 @@ class Trainer(TrainerEventTrigger): if model_load_fn is not None: if not callable(model_load_fn): raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") - rank_zero_call(model_load_fn)(folder) + model_load_fn(folder) else: if isinstance(folder, str): folder = Path(folder) @@ -653,7 +653,7 @@ class Trainer(TrainerEventTrigger): if model_load_fn is not None: if not callable(model_load_fn): raise ValueError("Parameter `model_save_fn` should be `Callable`.") - rank_zero_call(model_load_fn)(folder) + model_load_fn(folder) states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) else: states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)