mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00
small
This commit is contained in:
parent
4885b8237c
commit
4fbca22267
@ -23,9 +23,9 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
|
||||
elif driver in {"jittor"}:
|
||||
from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver
|
||||
return initialize_jittor_driver(driver, device, model, **kwargs)
|
||||
elif driver in {"paddle", "fleet"}:
|
||||
elif driver in {"paddle"}:
|
||||
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
|
||||
return initialize_paddle_driver(driver, device, model, **kwargs)
|
||||
else:
|
||||
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale', "
|
||||
"'jittor', 'paddle', 'fleet'].")
|
||||
"'jittor', 'paddle'].")
|
Loading…
Reference in New Issue
Block a user