This commit is contained in:
x54-729 2022-05-10 07:54:17 +00:00
parent 4885b8237c
commit 4fbca22267

View File

@ -23,9 +23,9 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
elif driver in {"jittor"}:
from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver
return initialize_jittor_driver(driver, device, model, **kwargs)
elif driver in {"paddle", "fleet"}:
elif driver in {"paddle"}:
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
return initialize_paddle_driver(driver, device, model, **kwargs)
else:
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale', "
"'jittor', 'paddle', 'fleet'].")
"'jittor', 'paddle'].")