mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 19:28:17 +08:00
加入了callback events 的测试
This commit is contained in:
parent
30af3b032f
commit
b88a15dabb
@ -4,8 +4,8 @@ from functools import reduce
|
|||||||
from fastNLP.core.callbacks.callback_events import Events, Filter
|
from fastNLP.core.callbacks.callback_events import Events, Filter
|
||||||
|
|
||||||
|
|
||||||
class TestFilter:
|
|
||||||
|
|
||||||
|
class TestFilter:
|
||||||
def test_params_check(self):
|
def test_params_check(self):
|
||||||
# 顺利通过
|
# 顺利通过
|
||||||
_filter1 = Filter(every=10)
|
_filter1 = Filter(every=10)
|
||||||
@ -80,35 +80,6 @@ class TestFilter:
|
|||||||
_res.append(cu_res)
|
_res.append(cu_res)
|
||||||
assert _res == [9]
|
assert _res == [9]
|
||||||
|
|
||||||
def test_filter_fn(self):
|
|
||||||
from torch.optim import SGD
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from fastNLP.core.controllers.trainer import Trainer
|
|
||||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
|
||||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
|
|
||||||
|
|
||||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
|
|
||||||
optimizer = SGD(model.parameters(), lr=0.0001)
|
|
||||||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
|
|
||||||
dataloader = DataLoader(dataset=dataset, batch_size=4)
|
|
||||||
|
|
||||||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
|
|
||||||
def filter_fn(filter, trainer):
|
|
||||||
if trainer.__heihei_test__ == 10:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
@Filter(filter_fn=filter_fn)
|
|
||||||
def _fn(trainer, data):
|
|
||||||
return data
|
|
||||||
|
|
||||||
_res = []
|
|
||||||
for i in range(100):
|
|
||||||
trainer.__heihei_test__ = i
|
|
||||||
cu_res = _fn(trainer, i)
|
|
||||||
if cu_res is not None:
|
|
||||||
_res.append(cu_res)
|
|
||||||
assert _res == [10]
|
|
||||||
|
|
||||||
def test_extract_filter_from_fn(self):
|
def test_extract_filter_from_fn(self):
|
||||||
@Filter(every=10)
|
@Filter(every=10)
|
||||||
@ -155,3 +126,119 @@ class TestFilter:
|
|||||||
assert _res == [w - 1 for w in range(60, 101, 10)]
|
assert _res == [w - 1 for w in range(60, 101, 10)]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.torch
|
||||||
|
def test_filter_fn_torch():
|
||||||
|
from torch.optim import SGD
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from fastNLP.core.controllers.trainer import Trainer
|
||||||
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||||
|
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
|
||||||
|
|
||||||
|
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
|
||||||
|
optimizer = SGD(model.parameters(), lr=0.0001)
|
||||||
|
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
|
||||||
|
dataloader = DataLoader(dataset=dataset, batch_size=4)
|
||||||
|
|
||||||
|
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
|
||||||
|
def filter_fn(filter, trainer):
|
||||||
|
if trainer.__heihei_test__ == 10:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@Filter(filter_fn=filter_fn)
|
||||||
|
def _fn(trainer, data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
_res = []
|
||||||
|
for i in range(100):
|
||||||
|
trainer.__heihei_test__ = i
|
||||||
|
cu_res = _fn(trainer, i)
|
||||||
|
if cu_res is not None:
|
||||||
|
_res.append(cu_res)
|
||||||
|
assert _res == [10]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallbackEvents:
|
||||||
|
def test_every(self):
|
||||||
|
|
||||||
|
# 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试;
|
||||||
|
event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1;
|
||||||
|
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
|
||||||
|
def _fn(data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
_res = []
|
||||||
|
for i in range(100):
|
||||||
|
cu_res = _fn(i)
|
||||||
|
if cu_res is not None:
|
||||||
|
_res.append(cu_res)
|
||||||
|
assert _res == list(range(100))
|
||||||
|
|
||||||
|
event_state = Events.on_train_begin(every=10)
|
||||||
|
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
|
||||||
|
def _fn(data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
_res = []
|
||||||
|
for i in range(100):
|
||||||
|
cu_res = _fn(i)
|
||||||
|
if cu_res is not None:
|
||||||
|
_res.append(cu_res)
|
||||||
|
assert _res == [w - 1 for w in range(10, 101, 10)]
|
||||||
|
|
||||||
|
def test_once(self):
|
||||||
|
event_state = Events.on_train_begin(once=10)
|
||||||
|
|
||||||
|
@Filter(once=event_state.once)
|
||||||
|
def _fn(data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
_res = []
|
||||||
|
for i in range(100):
|
||||||
|
cu_res = _fn(i)
|
||||||
|
if cu_res is not None:
|
||||||
|
_res.append(cu_res)
|
||||||
|
assert _res == [9]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.torch
|
||||||
|
def test_callback_events_torch():
|
||||||
|
from torch.optim import SGD
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from fastNLP.core.controllers.trainer import Trainer
|
||||||
|
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||||
|
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
|
||||||
|
|
||||||
|
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
|
||||||
|
optimizer = SGD(model.parameters(), lr=0.0001)
|
||||||
|
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
|
||||||
|
dataloader = DataLoader(dataset=dataset, batch_size=4)
|
||||||
|
|
||||||
|
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
|
||||||
|
def filter_fn(filter, trainer):
|
||||||
|
if trainer.__heihei_test__ == 10:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
event_state = Events.on_train_begin(filter_fn=filter_fn)
|
||||||
|
|
||||||
|
@Filter(filter_fn=event_state.filter_fn)
|
||||||
|
def _fn(trainer, data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
_res = []
|
||||||
|
for i in range(100):
|
||||||
|
trainer.__heihei_test__ = i
|
||||||
|
cu_res = _fn(trainer, i)
|
||||||
|
if cu_res is not None:
|
||||||
|
_res.append(cu_res)
|
||||||
|
assert _res == [10]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -221,124 +221,6 @@ def test_trainer_event_trigger_2(
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
|
|
||||||
@pytest.mark.torch
|
|
||||||
@magic_argv_env_context
|
|
||||||
def test_trainer_event_trigger_3(
|
|
||||||
model_and_optimizers: TrainerParameters,
|
|
||||||
driver,
|
|
||||||
device,
|
|
||||||
n_epochs=2,
|
|
||||||
):
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_after_trainer_initialized)
|
|
||||||
def on_after_trainer_initialized(trainer, driver):
|
|
||||||
print("on_after_trainer_initialized")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_sanity_check_begin)
|
|
||||||
def on_sanity_check_begin(trainer):
|
|
||||||
print("on_sanity_check_begin")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_sanity_check_end)
|
|
||||||
def on_sanity_check_end(trainer, sanity_check_res):
|
|
||||||
print("on_sanity_check_end")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_train_begin)
|
|
||||||
def on_train_begin(trainer):
|
|
||||||
print("on_train_begin")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_train_end)
|
|
||||||
def on_train_end(trainer):
|
|
||||||
print("on_train_end")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_train_epoch_begin)
|
|
||||||
def on_train_epoch_begin(trainer):
|
|
||||||
if trainer.cur_epoch_idx >= 1:
|
|
||||||
# 触发 on_exception;
|
|
||||||
raise Exception
|
|
||||||
print("on_train_epoch_begin")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_train_epoch_end)
|
|
||||||
def on_train_epoch_end(trainer):
|
|
||||||
print("on_train_epoch_end")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_fetch_data_begin)
|
|
||||||
def on_fetch_data_begin(trainer):
|
|
||||||
print("on_fetch_data_begin")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_fetch_data_end)
|
|
||||||
def on_fetch_data_end(trainer):
|
|
||||||
print("on_fetch_data_end")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_train_batch_begin)
|
|
||||||
def on_train_batch_begin(trainer, batch, indices=None):
|
|
||||||
print("on_train_batch_begin")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_train_batch_end)
|
|
||||||
def on_train_batch_end(trainer):
|
|
||||||
print("on_train_batch_end")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_exception)
|
|
||||||
def on_exception(trainer, exception):
|
|
||||||
print("on_exception")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_before_backward)
|
|
||||||
def on_before_backward(trainer, outputs):
|
|
||||||
print("on_before_backward")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_after_backward)
|
|
||||||
def on_after_backward(trainer):
|
|
||||||
print("on_after_backward")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_before_optimizers_step)
|
|
||||||
def on_before_optimizers_step(trainer, optimizers):
|
|
||||||
print("on_before_optimizers_step")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_after_optimizers_step)
|
|
||||||
def on_after_optimizers_step(trainer, optimizers):
|
|
||||||
print("on_after_optimizers_step")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_before_zero_grad)
|
|
||||||
def on_before_zero_grad(trainer, optimizers):
|
|
||||||
print("on_before_zero_grad")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_after_zero_grad)
|
|
||||||
def on_after_zero_grad(trainer, optimizers):
|
|
||||||
print("on_after_zero_grad")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_evaluate_begin)
|
|
||||||
def on_evaluate_begin(trainer):
|
|
||||||
print("on_evaluate_begin")
|
|
||||||
|
|
||||||
@Trainer.on(Events.on_evaluate_end)
|
|
||||||
def on_evaluate_end(trainer, results):
|
|
||||||
print("on_evaluate_end")
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
with Capturing() as output:
|
|
||||||
trainer = Trainer(
|
|
||||||
model=model_and_optimizers.model,
|
|
||||||
driver=driver,
|
|
||||||
device=device,
|
|
||||||
optimizers=model_and_optimizers.optimizers,
|
|
||||||
train_dataloader=model_and_optimizers.train_dataloader,
|
|
||||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
|
|
||||||
input_mapping=model_and_optimizers.input_mapping,
|
|
||||||
output_mapping=model_and_optimizers.output_mapping,
|
|
||||||
metrics=model_and_optimizers.metrics,
|
|
||||||
|
|
||||||
n_epochs=n_epochs,
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer.run()
|
|
||||||
|
|
||||||
if dist.is_initialized():
|
|
||||||
dist.destroy_process_group()
|
|
||||||
|
|
||||||
for name, member in Events.__members__.items():
|
|
||||||
assert member.value in output[0]
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user