mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
修复测试
This commit is contained in:
parent
40c0a712dd
commit
4781178a5a
@ -132,7 +132,6 @@ def test_trainer_torch_with_evaluator(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1)
|
||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]])
|
||||
@pytest.mark.parametrize("fp16", [True, False])
|
||||
@pytest.mark.parametrize("accumulation_steps", [1, 3])
|
||||
@magic_argv_env_context
|
||||
@ -140,12 +139,11 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
|
||||
model_and_optimizers: TrainerParameters,
|
||||
driver,
|
||||
device,
|
||||
callbacks,
|
||||
fp16,
|
||||
accumulation_steps,
|
||||
n_epochs=6,
|
||||
):
|
||||
|
||||
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]
|
||||
trainer = Trainer(
|
||||
model=model_and_optimizers.model,
|
||||
driver=driver,
|
||||
|
@ -77,3 +77,4 @@ def check_replace_sampler(driver):
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user