mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 12:47:35 +08:00
增加pytorch版本限制为1.6以上
This commit is contained in:
parent
6d2dca421d
commit
422510285e
@ -55,7 +55,7 @@ class TorchDriver(Driver):
|
||||
# 因为 ddp 和 single_device 的混合精度训练的设置是一样的,因此可以统一抽象到这里;
|
||||
self.fp16 = fp16
|
||||
if parse_version(torch.__version__) < parse_version('1.6'):
|
||||
raise RuntimeError("Pytorch supports float16 after version 1.6, please upgrade your pytorch version.")
|
||||
raise RuntimeError(f"Pytorch({torch.__version__}) need to be older than 1.6.")
|
||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
|
||||
self.grad_scaler = _grad_scaler()
|
||||
|
||||
|
@ -160,7 +160,7 @@ def _build_fp16_env(dummy=False):
|
||||
GradScaler = DummyGradScaler
|
||||
else:
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("No cuda")
|
||||
raise RuntimeError("Pytorch is not installed in gpu version, please use device='cpu'.")
|
||||
if torch.cuda.get_device_capability(0)[0] < 7:
|
||||
logger.rank_zero_warning(
|
||||
"NOTE: your device does NOT support faster training with fp16, "
|
||||
|
Loading…
Reference in New Issue
Block a user