mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 03:37:55 +08:00
small
This commit is contained in:
parent
5d1ac72ec9
commit
8cda30c426
@ -255,12 +255,13 @@ class TorchDriver(Driver):
|
||||
logger.debug("Load model...")
|
||||
|
||||
# 3. 加载fp16的状态
|
||||
if 'grad_scaler_state_dict' in states:
|
||||
grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
|
||||
if not isinstance(self.grad_scaler, DummyGradScaler):
|
||||
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
|
||||
self.auto_cast = torch.cuda.amp.autocast
|
||||
if "grad_scaler_state_dict" in states:
|
||||
grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
|
||||
if isinstance(self.grad_scaler, DummyGradScaler):
|
||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
|
||||
self.grad_scaler = _grad_scaler()
|
||||
self.fp16 = True
|
||||
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
|
||||
logger.debug("Load grad_scaler state dict...")
|
||||
elif not isinstance(self.grad_scaler, DummyGradScaler):
|
||||
logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
|
||||
|
Loading…
Reference in New Issue
Block a user