Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-12-02 20:27:35 +08:00
Added a test for functional callbacks
This commit is contained in:
parent 1528107480
commit 3069017aae
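For context, the new test exercises fastNLP's decorator-based ("functional") callback registration, where a handler is attached with `@Trainer.on(Events.<event_name>)` and fired by name during the run. The sketch below is a minimal, self-contained imitation of that pattern for illustration only; `MiniTrainer` and its plain string event names are hypothetical stand-ins, not fastNLP code.

```python
# Minimal sketch of the decorator-based ("functional") callback pattern that the
# new test exercises. Illustration only: MiniTrainer and the string event names
# are hypothetical stand-ins for fastNLP's Trainer.on(Events.<name>) API below.
from collections import defaultdict


class MiniTrainer:
    _handlers = defaultdict(list)  # event name -> registered callback functions

    @classmethod
    def on(cls, event_name):
        # Used as a decorator: @MiniTrainer.on("on_train_begin")
        def register(fn):
            cls._handlers[event_name].append(fn)
            return fn
        return register

    def _fire(self, event_name):
        for fn in self._handlers[event_name]:
            fn(self)

    def run(self):
        # A real trainer fires many more events; two are enough for the sketch.
        self._fire("on_train_begin")
        self._fire("on_train_end")


@MiniTrainer.on("on_train_begin")
def on_train_begin(trainer):
    print("on_train_begin")


@MiniTrainer.on("on_train_end")
def on_train_end(trainer):
    print("on_train_end")


if __name__ == "__main__":
    MiniTrainer().run()  # prints both event names, the signal the test asserts on
```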
@@ -67,7 +67,7 @@ def model_and_optimizers():
 @pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]])
 @pytest.mark.torch
 @magic_argv_env_context
-def test_trainer_event_trigger(
+def test_trainer_event_trigger_1(
         model_and_optimizers: TrainerParameters,
         driver,
         device,
@@ -101,5 +101,126 @@ def test_trainer_event_trigger(
             assert member.value in output[0]

+
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6), ("torch", [6, 7])])  # , ("torch", 6), ("torch", [6, 7])
+@pytest.mark.torch
+@magic_argv_env_context
+def test_trainer_event_trigger_2(
+        model_and_optimizers: TrainerParameters,
+        driver,
+        device,
+        n_epochs=2,
+):
+
+    @Trainer.on(Events.on_after_trainer_initialized)
+    def on_after_trainer_initialized(trainer, driver):
+        print("on_after_trainer_initialized")
+
+    @Trainer.on(Events.on_sanity_check_begin)
+    def on_sanity_check_begin(trainer):
+        print("on_sanity_check_begin")
+
+    @Trainer.on(Events.on_sanity_check_end)
+    def on_sanity_check_end(trainer, sanity_check_res):
+        print("on_sanity_check_end")
+
+    @Trainer.on(Events.on_train_begin)
+    def on_train_begin(trainer):
+        print("on_train_begin")
+
+    @Trainer.on(Events.on_train_end)
+    def on_train_end(trainer):
+        print("on_train_end")
+
+    @Trainer.on(Events.on_train_epoch_begin)
+    def on_train_epoch_begin(trainer):
+        if trainer.cur_epoch_idx >= 1:
+            # trigger on_exception;
+            raise Exception
+        print("on_train_epoch_begin")
+
+    @Trainer.on(Events.on_train_epoch_end)
+    def on_train_epoch_end(trainer):
+        print("on_train_epoch_end")
+
+    @Trainer.on(Events.on_fetch_data_begin)
+    def on_fetch_data_begin(trainer):
+        print("on_fetch_data_begin")
+
+    @Trainer.on(Events.on_fetch_data_end)
+    def on_fetch_data_end(trainer):
+        print("on_fetch_data_end")
+
+    @Trainer.on(Events.on_train_batch_begin)
+    def on_train_batch_begin(trainer, batch, indices=None):
+        print("on_train_batch_begin")
+
+    @Trainer.on(Events.on_train_batch_end)
+    def on_train_batch_end(trainer):
+        print("on_train_batch_end")
+
+    @Trainer.on(Events.on_exception)
+    def on_exception(trainer, exception):
+        print("on_exception")
+
+    @Trainer.on(Events.on_before_backward)
+    def on_before_backward(trainer, outputs):
+        print("on_before_backward")
+
+    @Trainer.on(Events.on_after_backward)
+    def on_after_backward(trainer):
+        print("on_after_backward")
+
+    @Trainer.on(Events.on_before_optimizers_step)
+    def on_before_optimizers_step(trainer, optimizers):
+        print("on_before_optimizers_step")
+
+    @Trainer.on(Events.on_after_optimizers_step)
+    def on_after_optimizers_step(trainer, optimizers):
+        print("on_after_optimizers_step")
+
+    @Trainer.on(Events.on_before_zero_grad)
+    def on_before_zero_grad(trainer, optimizers):
+        print("on_before_zero_grad")
+
+    @Trainer.on(Events.on_after_zero_grad)
+    def on_after_zero_grad(trainer, optimizers):
+        print("on_after_zero_grad")
+
+    @Trainer.on(Events.on_evaluate_begin)
+    def on_evaluate_begin(trainer):
+        print("on_evaluate_begin")
+
+    @Trainer.on(Events.on_evaluate_end)
+    def on_evaluate_end(trainer, results):
+        print("on_evaluate_end")
+
+    with pytest.raises(Exception):
+        with Capturing() as output:
+            trainer = Trainer(
+                model=model_and_optimizers.model,
+                driver=driver,
+                device=device,
+                optimizers=model_and_optimizers.optimizers,
+                train_dataloader=model_and_optimizers.train_dataloader,
+                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+                input_mapping=model_and_optimizers.input_mapping,
+                output_mapping=model_and_optimizers.output_mapping,
+                metrics=model_and_optimizers.metrics,
+
+                n_epochs=n_epochs,
+            )
+
+            trainer.run()
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+    for name, member in Events.__members__.items():
+        assert member.value in output[0]
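The verification strategy in the new test is: every callback prints its event name, the whole run is wrapped in a stdout capture, and the test then checks that each member of `Events` appears in the captured text. `Capturing` is a helper from the repository's test utilities; the sketch below reproduces only the idea with the standard library, where `emit_events` is a hypothetical stand-in for `trainer.run()`, not the actual helper.

```python
# Standard-library analogue of the capture-and-assert pattern used in the test.
# emit_events() is a hypothetical stand-in for trainer.run(); the real test
# captures output with the repo's Capturing helper rather than redirect_stdout.
import io
from contextlib import redirect_stdout


def emit_events():
    # In the real test each registered callback prints its event name.
    print("on_train_begin")
    print("on_train_end")


buf = io.StringIO()
with redirect_stdout(buf):
    emit_events()

output = buf.getvalue()
for expected in ("on_train_begin", "on_train_end"):
    # Mirrors `assert member.value in output[0]` from the test.
    assert expected in output
```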