mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00
[bugfix] 修复 tests/core/controllers/_test_trainer_jittor.py,使其可以正常运行 (#415)
* 修复 tests/core/controllers/_test_trainer_jittor.py,使其可以正常运行 Trainer 并不接收 validate_dataloaders 参数,改为 evaluate_dataloaders 即可。 * jittor single driver 支持 cpu 和 gpu 的切换
This commit is contained in:
parent
da0b747b30
commit
5425095cac
@ -8,7 +8,7 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
|
|||||||
from fastNLP.core.log import logger
|
from fastNLP.core.log import logger
|
||||||
|
|
||||||
if _NEED_IMPORT_JITTOR:
|
if _NEED_IMPORT_JITTOR:
|
||||||
import jittor
|
import jittor as jt
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"JittorSingleDriver",
|
"JittorSingleDriver",
|
||||||
@ -105,6 +105,9 @@ class JittorSingleDriver(JittorDriver):
|
|||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
"""
|
"""
|
||||||
使用单个 GPU 时,jittor 底层自动实现调配,无需额外操作
|
支持 cpu 和 gpu 的切换
|
||||||
"""
|
"""
|
||||||
pass
|
if self.model_device in ["cpu", None]:
|
||||||
|
jt.flags.use_cuda = 0 # 使用 cpu
|
||||||
|
else:
|
||||||
|
jt.flags.use_cuda = 1 # 使用 cuda
|
||||||
|
@ -225,7 +225,7 @@ if __name__ == "__main__":
|
|||||||
device=[0,1,2,3,4],
|
device=[0,1,2,3,4],
|
||||||
optimizers=optimizer,
|
optimizers=optimizer,
|
||||||
train_dataloader=train_dataloader,
|
train_dataloader=train_dataloader,
|
||||||
validate_dataloaders=val_dataloader,
|
evaluate_dataloaders=val_dataloader,
|
||||||
validate_every=-1,
|
validate_every=-1,
|
||||||
input_mapping=None,
|
input_mapping=None,
|
||||||
output_mapping=None,
|
output_mapping=None,
|
||||||
|
@ -69,7 +69,8 @@ class TrainJittorConfig:
|
|||||||
shuffle: bool = True
|
shuffle: bool = True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("driver,device", [("jittor", None)])
|
@pytest.mark.parametrize("driver", ["jittor"])
|
||||||
|
@pytest.mark.parametrize("device", ["cpu", 1])
|
||||||
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
|
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
|
||||||
@pytest.mark.jittor
|
@pytest.mark.jittor
|
||||||
def test_trainer_jittor(
|
def test_trainer_jittor(
|
||||||
@ -134,4 +135,5 @@ def test_trainer_jittor(
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# test_trainer_jittor("jittor", None, [RichCallback(100)])
|
# test_trainer_jittor("jittor", None, [RichCallback(100)])
|
||||||
|
# test_trainer_jittor("jittor", 1, [RichCallback(100)])
|
||||||
pytest.main(['test_trainer_jittor.py']) # 只运行此模块
|
pytest.main(['test_trainer_jittor.py']) # 只运行此模块
|
||||||
|
Loading…
Reference in New Issue
Block a user