Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

This commit is contained in:
x54-729 2022-04-12 20:48:47 +00:00
commit c61f28ce8e
5 changed files with 38 additions and 15 deletions

View File

@ -71,7 +71,7 @@ class Callback:
"""
pass
def on_train_batch_begin(self, trainer, batch, indices=None):
def on_train_batch_begin(self, trainer, batch, indices):
r"""
在训练过程中开始具体的一个 batch 前会被触发

View File

@ -130,9 +130,12 @@ class Trainer(TrainerEventTrigger):
auto 表示如果检测到当前 terminal 为交互型 则使用 rich否则使用 raw
"""
self.model = model
self.marker = marker
self.driver_name = driver
if isinstance(driver, str):
self.driver_name = driver
else:
self.driver_name = driver.__class__.__name__
self.device = device
self.fp16 = fp16
self.input_mapping = input_mapping
@ -157,6 +160,8 @@ class Trainer(TrainerEventTrigger):
elif accumulation_steps < 0:
raise ValueError("Parameter `accumulation_steps` can only be bigger than 0.")
self.accumulation_steps = accumulation_steps
# todo: rough idea — each driver should expose which parameters it accepts, mapped back to the ones used at initialization; then trainer/evaluator can check at init time whether the parameters they hold are consistent with the driver's, and warn the user about any values that differ. This can probably be deferred until later.
self.driver = choose_driver(
model=model,
driver=driver,
@ -403,9 +408,10 @@ class Trainer(TrainerEventTrigger):
def wrapper(fn: Callable) -> Callable:
cls._custom_callbacks[marker].append((event, fn))
assert check_fn_not_empty_params(fn, len(get_fn_arg_names(getattr(Callback, event.value))) - 1), "Your " \
"callback fn's allowed parameters seem not to be equal with the origin callback fn in class " \
"`Callback` with the same callback time."
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
assert check_fn_not_empty_params(fn, len(callback_fn_args)), \
f"The callback function at `{event.value.lower()}`'s parameters should be {callback_fn_args}, but your "\
f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}."
return fn
return wrapper
@ -807,10 +813,6 @@ class Trainer(TrainerEventTrigger):
def data_device(self):
return self.driver.data_device
@property
def model(self):
    # Return the model held by the driver rather than the raw model passed to
    # `Trainer.__init__`; note that it may be wrapped by a distributed wrapper
    # such as `DistributedDataParallel`.
    return self.driver.model

View File

@ -44,15 +44,11 @@ __all__ = [
]
def get_fn_arg_names(fn: Callable) -> List[str]:
    r"""
    Return the names of all parameters of a function.

    :param fn: the function to inspect
    :return: a list containing the string names of the function's parameters,
        in declaration order
    """
    signature = inspect.signature(fn)
    return [param_name for param_name in signature.parameters]

View File

@ -1,7 +1,7 @@
import pytest
from functools import reduce
from fastNLP.core.callbacks.callback_events import Filter
from fastNLP.core.callbacks.callback_events import Events, Filter
class TestFilter:

View File

@ -0,0 +1,25 @@
import pytest
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks import Events
from tests.helpers.utils import magic_argv_env_context
@magic_argv_env_context
def test_trainer_torch_without_evaluator():
    """Check that `Trainer.on` validates the signature of registered callbacks."""
    # A callback for `on_train_epoch_begin` only needs the `trainer` argument.
    @Trainer.on(Events.ON_TRAIN_EPOCH_BEGIN(every=10))
    def fn1(trainer):
        pass

    # An `on_train_batch_begin` callback must accept `trainer`, `batch` and `indices`.
    @Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10))
    def fn2(trainer, batch, indices):
        pass

    # Registering a callback that is missing the `indices` parameter should be
    # rejected by the signature check inside `Trainer.on`.
    with pytest.raises(AssertionError):
        @Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10))
        def fn3(trainer, batch):
            pass