修复测试的一些bug

This commit is contained in:
x54-729 2022-09-20 13:31:33 +08:00
parent 82fe167ea3
commit babf4b2f19
3 changed files with 4 additions and 8 deletions

View File

@@ -75,7 +75,7 @@ def model_and_optimizers(request):
@pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
@magic_argv_env_context(timeout=100)
def test_model_checkpoint_callback_1(
model_and_optimizers: TrainerParameters,

View File

@@ -11,12 +11,8 @@ from ...helpers.utils import Capturing
 def _assert_equal(d1, d2):
     try:
         if 'torch' in str(type(d1)):
-            if 'float64' in str(d2.dtype):
-                print(d2.dtype)
             assert (d1 == d2).all().item()
-        if 'oneflow' in str(type(d1)):
-            if 'float64' in str(d2.dtype):
-                print(d2.dtype)
+        elif 'oneflow' in str(type(d1)):
             assert (d1 == d2).all().item()
         else:
             assert all(d1 == d2)

View File

@@ -67,7 +67,7 @@ def model_and_optimizers(request):
return trainer_params
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上")
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_torch_without_evaluator(
@@ -97,7 +97,7 @@ def test_trainer_torch_without_evaluator(
dist.destroy_process_group()
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上")
@pytest.mark.torch
@pytest.mark.parametrize("save_on_rank0", [True, False])
@magic_argv_env_context(timeout=100)