修复测试

This commit is contained in:
yh_cc 2022-04-08 22:36:52 +08:00
parent 40c0a712dd
commit 4781178a5a
2 changed files with 2 additions and 3 deletions

View File

@ -132,7 +132,6 @@ def test_trainer_torch_with_evaluator(
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]])
@pytest.mark.parametrize("fp16", [True, False])
@pytest.mark.parametrize("accumulation_steps", [1, 3])
@magic_argv_env_context
@ -140,12 +139,11 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
model_and_optimizers: TrainerParameters,
driver,
device,
callbacks,
fp16,
accumulation_steps,
n_epochs=6,
):
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,

View File

@ -77,3 +77,4 @@ def check_replace_sampler(driver):