mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00
[bugfix] 修复 tests/core/controllers/_test_trainer_jittor.py,使其可以正常运行 (#415)
* 修复 tests/core/controllers/_test_trainer_jittor.py,使其可以正常运行 Trainer 并不接收 validate_dataloaders 参数,改为 evaluate_dataloaders 即可。 * jittor single driver 支持 cpu 和 gpu 的切换
This commit is contained in:
parent
da0b747b30
commit
5425095cac
@ -8,7 +8,7 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
|
||||
from fastNLP.core.log import logger
|
||||
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor
|
||||
import jittor as jt
|
||||
|
||||
__all__ = [
|
||||
"JittorSingleDriver",
|
||||
@ -105,6 +105,9 @@ class JittorSingleDriver(JittorDriver):
|
||||
|
||||
def setup(self):
|
||||
"""
|
||||
使用单个 GPU 时,jittor 底层自动实现调配,无需额外操作
|
||||
支持 cpu 和 gpu 的切换
|
||||
"""
|
||||
pass
|
||||
if self.model_device in ["cpu", None]:
|
||||
jt.flags.use_cuda = 0 # 使用 cpu
|
||||
else:
|
||||
jt.flags.use_cuda = 1 # 使用 cuda
|
||||
|
@ -225,7 +225,7 @@ if __name__ == "__main__":
|
||||
device=[0,1,2,3,4],
|
||||
optimizers=optimizer,
|
||||
train_dataloader=train_dataloader,
|
||||
validate_dataloaders=val_dataloader,
|
||||
evaluate_dataloaders=val_dataloader,
|
||||
validate_every=-1,
|
||||
input_mapping=None,
|
||||
output_mapping=None,
|
||||
|
@ -69,7 +69,8 @@ class TrainJittorConfig:
|
||||
shuffle: bool = True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("driver,device", [("jittor", None)])
|
||||
@pytest.mark.parametrize("driver", ["jittor"])
|
||||
@pytest.mark.parametrize("device", ["cpu", 1])
|
||||
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
|
||||
@pytest.mark.jittor
|
||||
def test_trainer_jittor(
|
||||
@ -134,4 +135,5 @@ def test_trainer_jittor(
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_trainer_jittor("jittor", None, [RichCallback(100)])
|
||||
# test_trainer_jittor("jittor", 1, [RichCallback(100)])
|
||||
pytest.main(['test_trainer_jittor.py']) # 只运行此模块
|
||||
|
Loading…
Reference in New Issue
Block a user