mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 11:48:09 +08:00
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0
This commit is contained in:
commit
c61f28ce8e
@ -71,7 +71,7 @@ class Callback:
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_train_batch_begin(self, trainer, batch, indices=None):
|
||||
def on_train_batch_begin(self, trainer, batch, indices):
|
||||
r"""
|
||||
在训练过程中开始具体的一个 batch 前会被触发;
|
||||
|
||||
|
@ -130,9 +130,12 @@ class Trainer(TrainerEventTrigger):
|
||||
auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。
|
||||
|
||||
"""
|
||||
|
||||
self.model = model
|
||||
self.marker = marker
|
||||
self.driver_name = driver
|
||||
if isinstance(driver, str):
|
||||
self.driver_name = driver
|
||||
else:
|
||||
self.driver_name = driver.__class__.__name__
|
||||
self.device = device
|
||||
self.fp16 = fp16
|
||||
self.input_mapping = input_mapping
|
||||
@ -157,6 +160,8 @@ class Trainer(TrainerEventTrigger):
|
||||
elif accumulation_steps < 0:
|
||||
raise ValueError("Parameter `accumulation_steps` can only be bigger than 0.")
|
||||
self.accumulation_steps = accumulation_steps
|
||||
|
||||
# todo 思路大概是,每个driver提供一下自己的参数是啥(需要对应回初始化的那个),然后trainer/evalutor在初始化的时候,就检测一下自己手上的参数和driver的是不是一致的,不一致的地方需要warn用户说这些值driver不太一样。感觉可以留到后面做吧
|
||||
self.driver = choose_driver(
|
||||
model=model,
|
||||
driver=driver,
|
||||
@ -403,9 +408,10 @@ class Trainer(TrainerEventTrigger):
|
||||
|
||||
def wrapper(fn: Callable) -> Callable:
|
||||
cls._custom_callbacks[marker].append((event, fn))
|
||||
assert check_fn_not_empty_params(fn, len(get_fn_arg_names(getattr(Callback, event.value))) - 1), "Your " \
|
||||
"callback fn's allowed parameters seem not to be equal with the origin callback fn in class " \
|
||||
"`Callback` with the same callback time."
|
||||
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
|
||||
assert check_fn_not_empty_params(fn, len(callback_fn_args)), \
|
||||
f"The callback function at `{event.value.lower()}`'s parameters should be {callback_fn_args}, but your "\
|
||||
f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}."
|
||||
return fn
|
||||
|
||||
return wrapper
|
||||
@ -807,10 +813,6 @@ class Trainer(TrainerEventTrigger):
|
||||
def data_device(self):
|
||||
return self.driver.data_device
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
# 返回 driver 中的 model,注意该 model 可能被分布式的模型包裹,例如 `DistributedDataParallel`;
|
||||
return self.driver.model
|
||||
|
||||
|
||||
|
||||
|
@ -44,15 +44,11 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_fn_arg_names(fn: Callable) -> List[str]:
|
||||
r"""
|
||||
返回一个函数的所有参数的名字;
|
||||
|
||||
:param fn: 需要查询的函数;
|
||||
|
||||
:return: 一个列表,其中的元素则是查询函数的参数的字符串名字;
|
||||
"""
|
||||
return list(inspect.signature(fn).parameters)
|
||||
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from functools import reduce
|
||||
|
||||
from fastNLP.core.callbacks.callback_events import Filter
|
||||
from fastNLP.core.callbacks.callback_events import Events, Filter
|
||||
|
||||
|
||||
class TestFilter:
|
||||
|
25
tests/core/controllers/test_trainer_other_things.py
Normal file
25
tests/core/controllers/test_trainer_other_things.py
Normal file
@ -0,0 +1,25 @@
|
||||
import pytest
|
||||
|
||||
from fastNLP.core.controllers.trainer import Trainer
|
||||
from fastNLP.core.callbacks import Events
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
|
||||
|
||||
@magic_argv_env_context
|
||||
def test_trainer_torch_without_evaluator():
|
||||
@Trainer.on(Events.ON_TRAIN_EPOCH_BEGIN(every=10))
|
||||
def fn1(trainer):
|
||||
pass
|
||||
|
||||
@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10))
|
||||
def fn2(trainer, batch, indices):
|
||||
pass
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10))
|
||||
def fn3(trainer, batch):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user