增加pytorch版本限制为1.6以上

This commit is contained in:
yh 2022-05-17 22:39:19 +08:00
parent 6d2dca421d
commit 422510285e
2 changed files with 2 additions and 2 deletions

View File

@ -55,7 +55,7 @@ class TorchDriver(Driver):
# 因为 ddp 和 single_device 的混合精度训练的设置是一样的,因此可以统一抽象到这里;
self.fp16 = fp16
if parse_version(torch.__version__) < parse_version('1.6'):
raise RuntimeError("Pytorch supports float16 after version 1.6, please upgrade your pytorch version.")
raise RuntimeError(f"Pytorch({torch.__version__}) must not be older than 1.6, since fp16 (amp) requires torch>=1.6.")
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
self.grad_scaler = _grad_scaler()

View File

@ -160,7 +160,7 @@ def _build_fp16_env(dummy=False):
GradScaler = DummyGradScaler
else:
if not torch.cuda.is_available():
raise RuntimeError("No cuda")
raise RuntimeError("Pytorch is not installed with CUDA (GPU) support, please use device='cpu'.")
if torch.cuda.get_device_capability(0)[0] < 7:
logger.rank_zero_warning(
"NOTE: your device does NOT support faster training with fp16, "