From 3069017aaef12cb1afd01c5691648df54cd35faa Mon Sep 17 00:00:00 2001 From: YWMditto Date: Tue, 3 May 2022 16:39:35 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=20=E5=87=BD=E6=95=B0=E5=BC=8F=20callback=20=E7=9A=84test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controllers/test_trainer_event_trigger.py | 123 +++++++++++++++++- 1 file changed, 122 insertions(+), 1 deletion(-) diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py index bcd89614..84752287 100644 --- a/tests/core/controllers/test_trainer_event_trigger.py +++ b/tests/core/controllers/test_trainer_event_trigger.py @@ -67,7 +67,7 @@ def model_and_optimizers(): @pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]]) @pytest.mark.torch @magic_argv_env_context -def test_trainer_event_trigger( +def test_trainer_event_trigger_1( model_and_optimizers: TrainerParameters, driver, device, @@ -101,5 +101,126 @@ def test_trainer_event_trigger( assert member.value in output[0] +@pytest.mark.parametrize("driver,device", [("torch", "cpu"),("torch", 6), ("torch", [6, 7])]) # , ("torch", 6), ("torch", [6, 7]) +@pytest.mark.torch +@magic_argv_env_context +def test_trainer_event_trigger_2( + model_and_optimizers: TrainerParameters, + driver, + device, + n_epochs=2, +): + + @Trainer.on(Events.on_after_trainer_initialized) + def on_after_trainer_initialized(trainer, driver): + print("on_after_trainer_initialized") + + @Trainer.on(Events.on_sanity_check_begin) + def on_sanity_check_begin(trainer): + print("on_sanity_check_begin") + + @Trainer.on(Events.on_sanity_check_end) + def on_sanity_check_end(trainer, sanity_check_res): + print("on_sanity_check_end") + + @Trainer.on(Events.on_train_begin) + def on_train_begin(trainer): + print("on_train_begin") + + @Trainer.on(Events.on_train_end) + def on_train_end(trainer): + print("on_train_end") + + @Trainer.on(Events.on_train_epoch_begin) + def on_train_epoch_begin(trainer): + if trainer.cur_epoch_idx >= 1: + # 触发 on_exception; + raise Exception + print("on_train_epoch_begin") + + @Trainer.on(Events.on_train_epoch_end) + def on_train_epoch_end(trainer): + print("on_train_epoch_end") + + @Trainer.on(Events.on_fetch_data_begin) + def on_fetch_data_begin(trainer): + print("on_fetch_data_begin") + + @Trainer.on(Events.on_fetch_data_end) + def on_fetch_data_end(trainer): + print("on_fetch_data_end") + + @Trainer.on(Events.on_train_batch_begin) + def on_train_batch_begin(trainer, batch, indices=None): + print("on_train_batch_begin") + + @Trainer.on(Events.on_train_batch_end) + def on_train_batch_end(trainer): + print("on_train_batch_end") + + @Trainer.on(Events.on_exception) + def on_exception(trainer, exception): + print("on_exception") + + @Trainer.on(Events.on_before_backward) + def on_before_backward(trainer, outputs): + print("on_before_backward") + + @Trainer.on(Events.on_after_backward) + def on_after_backward(trainer): + print("on_after_backward") + + @Trainer.on(Events.on_before_optimizers_step) + def on_before_optimizers_step(trainer, optimizers): + print("on_before_optimizers_step") + + @Trainer.on(Events.on_after_optimizers_step) + def on_after_optimizers_step(trainer, optimizers): + print("on_after_optimizers_step") + + @Trainer.on(Events.on_before_zero_grad) + def on_before_zero_grad(trainer, optimizers): + print("on_before_zero_grad") + + @Trainer.on(Events.on_after_zero_grad) + def on_after_zero_grad(trainer, optimizers): + print("on_after_zero_grad") + + @Trainer.on(Events.on_evaluate_begin) + def on_evaluate_begin(trainer): + print("on_evaluate_begin") + + @Trainer.on(Events.on_evaluate_end) + def on_evaluate_end(trainer, results): + print("on_evaluate_end") + + with pytest.raises(Exception): + with Capturing() as output: + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=n_epochs, + ) + + trainer.run() + + if dist.is_initialized(): + dist.destroy_process_group() + + for name, member in Events.__members__.items(): + assert member.value in output[0] + + + + +