mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 19:57:34 +08:00
修复测试的一些bug
This commit is contained in:
parent
82fe167ea3
commit
babf4b2f19
@ -75,7 +75,7 @@ def model_and_optimizers(request):
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
|
||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
|
||||
@magic_argv_env_context(timeout=100)
|
||||
def test_model_checkpoint_callback_1(
|
||||
model_and_optimizers: TrainerParameters,
|
||||
|
@ -11,12 +11,8 @@ from ...helpers.utils import Capturing
|
||||
def _assert_equal(d1, d2):
|
||||
try:
|
||||
if 'torch' in str(type(d1)):
|
||||
if 'float64' in str(d2.dtype):
|
||||
print(d2.dtype)
|
||||
assert (d1 == d2).all().item()
|
||||
if 'oneflow' in str(type(d1)):
|
||||
if 'float64' in str(d2.dtype):
|
||||
print(d2.dtype)
|
||||
elif 'oneflow' in str(type(d1)):
|
||||
assert (d1 == d2).all().item()
|
||||
else:
|
||||
assert all(d1 == d2)
|
||||
|
@ -67,7 +67,7 @@ def model_and_optimizers(request):
|
||||
|
||||
return trainer_params
|
||||
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上")
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上")
|
||||
@pytest.mark.torch
|
||||
@magic_argv_env_context
|
||||
def test_trainer_torch_without_evaluator(
|
||||
@ -97,7 +97,7 @@ def test_trainer_torch_without_evaluator(
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上")
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上")
|
||||
@pytest.mark.torch
|
||||
@pytest.mark.parametrize("save_on_rank0", [True, False])
|
||||
@magic_argv_env_context(timeout=100)
|
||||
|
Loading…
Reference in New Issue
Block a user