mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 11:17:50 +08:00
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0
This commit is contained in:
commit
4146f8f348
@ -4,8 +4,6 @@ from types import DynamicClassAttribute
|
||||
from functools import wraps
|
||||
|
||||
|
||||
import fastNLP
|
||||
|
||||
__all__ = [
|
||||
'Events',
|
||||
'EventsList',
|
||||
|
@ -16,7 +16,7 @@ SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None]
|
||||
CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend
|
||||
|
||||
|
||||
def _get_backend():
|
||||
def _get_backend() -> str:
|
||||
"""
|
||||
当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个:
|
||||
(1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是
|
||||
@ -57,7 +57,7 @@ def _get_backend():
|
||||
else:
|
||||
break
|
||||
if len(catch_backend):
|
||||
logger.debug(f"Find a file named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.")
|
||||
logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.")
|
||||
return catch_backend[0]
|
||||
|
||||
# 方式 (2)
|
||||
@ -66,7 +66,7 @@ def _get_backend():
|
||||
if catch_backend:
|
||||
break
|
||||
if len(catch_backend):
|
||||
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.")
|
||||
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
|
||||
return catch_backend[0]
|
||||
|
||||
return 'numpy'
|
||||
@ -80,7 +80,7 @@ class Collator:
|
||||
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。
|
||||
|
||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None]。
|
||||
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没有影响,不能 pad
|
||||
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没有影响,不能 pad
|
||||
的数据返回一定是 list 。
|
||||
"""
|
||||
self.unpack_batch_func = None
|
||||
@ -144,15 +144,18 @@ class Collator:
|
||||
for key in unpack_batch.keys():
|
||||
if key not in self.input_fields and key not in self.ignore_fields:
|
||||
self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend}
|
||||
elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto':
|
||||
self.input_fields[key]['backend'] = self.backend
|
||||
|
||||
for field_name, setting in self.input_fields.items():
|
||||
pad_fn = setting.get('pad_fn', None)
|
||||
if callable(pad_fn):
|
||||
padder = pad_fn
|
||||
else:
|
||||
backend = self.backend if setting['backend'] == 'auto' else setting['backend']
|
||||
batch_field = unpack_batch.get(field_name)
|
||||
padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'],
|
||||
dtype=setting['dtype'], backend=setting['backend'],
|
||||
dtype=setting['dtype'], backend=backend,
|
||||
field_name=field_name)
|
||||
self.padders[field_name] = padder
|
||||
if self.batch_data_type == 'l':
|
||||
|
@ -13,7 +13,6 @@ if _NEED_IMPORT_PADDLE:
|
||||
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
|
||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
|
||||
@dataclass
|
||||
|
@ -100,17 +100,16 @@ def model_and_optimizers(request):
|
||||
# 测试一下普通的情况;
|
||||
@pytest.mark.torch
|
||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
|
||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]])
|
||||
@pytest.mark.parametrize("evaluate_every", [-3, -1, 100])
|
||||
@magic_argv_env_context
|
||||
def test_trainer_torch_with_evaluator(
|
||||
model_and_optimizers: TrainerParameters,
|
||||
driver,
|
||||
device,
|
||||
callbacks,
|
||||
evaluate_every,
|
||||
n_epochs=10,
|
||||
):
|
||||
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]
|
||||
trainer = Trainer(
|
||||
model=model_and_optimizers.model,
|
||||
driver=driver,
|
||||
@ -172,7 +171,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
@pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1)
|
||||
@magic_argv_env_context
|
||||
def test_trainer_validate_every(
|
||||
@ -184,9 +183,7 @@ def test_trainer_validate_every(
|
||||
|
||||
def validate_every(trainer):
|
||||
if trainer.global_forward_batches % 10 == 0:
|
||||
print(trainer)
|
||||
print("\nfastNLP test validate every.\n")
|
||||
print(trainer.global_forward_batches)
|
||||
return True
|
||||
|
||||
trainer = Trainer(
|
||||
|
@ -30,12 +30,12 @@ def recover_logger(fn):
|
||||
return wrapper
|
||||
|
||||
|
||||
def magic_argv_env_context(fn=None, timeout=600):
|
||||
def magic_argv_env_context(fn=None, timeout=300):
|
||||
"""
|
||||
用来在测试时包裹每一个单独的测试函数,使得 ddp 测试正确;
|
||||
会丢掉 pytest 中的 arg 参数。
|
||||
|
||||
:param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 10 分钟,单位为秒;
|
||||
:param timeout: 表示一个测试如果经过多久还没有通过的话就主动将其 kill 掉,默认为 5 分钟,单位为秒;
|
||||
:return:
|
||||
"""
|
||||
# 说明是通过 @magic_argv_env_context(timeout=600) 调用;
|
||||
|
Loading…
Reference in New Issue
Block a user