[bugfix] 修复 tests/core/controllers/_test_trainer_jittor.py,使其可以正常运行 (#415)

* 修复 tests/core/controllers/_test_trainer_jittor.py,使其可以正常运行
Trainer 并不接收 validate_dataloaders 参数,改为 evaluate_dataloaders 即可。

* jittor single driver 支持 cpu 和 gpu 的切换
This commit is contained in:
Letian Li 2022-05-25 07:07:37 +01:00 committed by GitHub
parent da0b747b30
commit 5425095cac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 5 deletions

View File

@ -8,7 +8,7 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.log import logger from fastNLP.core.log import logger
if _NEED_IMPORT_JITTOR: if _NEED_IMPORT_JITTOR:
import jittor import jittor as jt
__all__ = [ __all__ = [
"JittorSingleDriver", "JittorSingleDriver",
@ -105,6 +105,9 @@ class JittorSingleDriver(JittorDriver):
def setup(self): def setup(self):
""" """
使用单个 GPU jittor 底层自动实现调配无需额外操作 支持 cpu gpu 的切换
""" """
pass if self.model_device in ["cpu", None]:
jt.flags.use_cuda = 0 # 使用 cpu
else:
jt.flags.use_cuda = 1 # 使用 cuda

View File

@ -225,7 +225,7 @@ if __name__ == "__main__":
device=[0,1,2,3,4], device=[0,1,2,3,4],
optimizers=optimizer, optimizers=optimizer,
train_dataloader=train_dataloader, train_dataloader=train_dataloader,
validate_dataloaders=val_dataloader, evaluate_dataloaders=val_dataloader,
validate_every=-1, validate_every=-1,
input_mapping=None, input_mapping=None,
output_mapping=None, output_mapping=None,

View File

@ -69,7 +69,8 @@ class TrainJittorConfig:
shuffle: bool = True shuffle: bool = True
@pytest.mark.parametrize("driver,device", [("jittor", None)]) @pytest.mark.parametrize("driver", ["jittor"])
@pytest.mark.parametrize("device", ["cpu", 1])
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) @pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
@pytest.mark.jittor @pytest.mark.jittor
def test_trainer_jittor( def test_trainer_jittor(
@ -134,4 +135,5 @@ def test_trainer_jittor(
if __name__ == "__main__": if __name__ == "__main__":
# test_trainer_jittor("jittor", None, [RichCallback(100)]) # test_trainer_jittor("jittor", None, [RichCallback(100)])
# test_trainer_jittor("jittor", 1, [RichCallback(100)])
pytest.main(['test_trainer_jittor.py']) # 只运行此模块 pytest.main(['test_trainer_jittor.py']) # 只运行此模块