1.修复数据处理的时候,多进程print会报错的问题;2.Trainer/Evaluator增加check_dataloader_legality

This commit is contained in:
yh 2022-05-23 19:43:29 +08:00
parent cfe389631a
commit ff50437702
15 changed files with 158 additions and 82 deletions

View File

@ -111,6 +111,7 @@ class Evaluator:
分布式进行设置如果为 ``True``将使得每个进程上的 ``dataloader`` 自动使用不同数据所有进程的数据并集是整个数据集
* *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数
* *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数
* *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法默认为 ``True``
"""
@ -134,6 +135,8 @@ class Evaluator:
self.device = device
self.verbose = verbose
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn)
if evaluate_batch_step_fn is not None:
_check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn')
self.evaluate_batch_step_fn = evaluate_batch_step_fn
@ -141,10 +144,23 @@ class Evaluator:
self.input_mapping = input_mapping
self.output_mapping = output_mapping
# check dataloader
if not isinstance(dataloaders, dict):
if kwargs.get('check_dataloader_legality', True):
try:
self.driver.check_dataloader_legality(dataloader=dataloaders)
except TypeError as e:
logger.error("`dataloaders` is invalid.")
raise e
dataloaders = {None: dataloaders}
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn)
else:
if kwargs.get('check_dataloader_legality', True):
for key, dataloader in dataloaders.items():
try:
self.driver.check_dataloader_legality(dataloader=dataloader)
except TypeError as e:
logger.error(f"The dataloader named:{key} is invalid.")
raise e
self.driver.setup()
self.driver.barrier()
@ -333,7 +349,7 @@ class Evaluator:
@evaluate_batch_loop.setter
def evaluate_batch_loop(self, loop: Loop):
if self.evaluate_batch_step_fn is not None:
if getattr(self, 'evaluate_step_fn', None) is not None:
logger.rank_zero_warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored "
"when the `evaluate_batch_loop` is also customized.")
self._evaluate_batch_loop = loop

View File

@ -304,6 +304,7 @@ class Trainer(TrainerEventTrigger):
* *train_output_mapping* -- output_mapping 一致但是只用于 ``Trainer`` output_mapping 互斥
* *evaluate_input_mapping* -- input_mapping 一致但是只用于 ``Evaluator`` input_mapping 互斥
* *evaluate_output_mapping* -- output_mapping 一致但是只用于 ``Evaluator`` output_mapping 互斥
* *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法默认为 ``True``
.. note::
``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证
@ -463,6 +464,14 @@ class Trainer(TrainerEventTrigger):
self.driver.setup()
self.driver.barrier()
# check train_dataloader
if kwargs.get('check_dataloader_legality', True):
try:
self.driver.check_dataloader_legality(dataloader=train_dataloader)
except TypeError as e:
logger.error("`train_dataloader` is invalid.")
raise e
use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed())
if use_dist_sampler:
_dist_sampler = "dist"
@ -482,7 +491,8 @@ class Trainer(TrainerEventTrigger):
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
progress_bar=progress_bar)
progress_bar=progress_bar,
check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
if train_fn is not None and not isinstance(train_fn, str):
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")

View File

@ -10,8 +10,7 @@ __all__ = [
"prepare_dataloader"
]
from .mix_dataloader import MixDataLoader
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
from .prepare_dataloader import prepare_dataloader

View File

@ -1,6 +1,8 @@
__all__ = [
"TorchDataLoader",
"prepare_torch_dataloader"
"prepare_torch_dataloader",
"MixDataLoader"
]
from .fdl import TorchDataLoader, prepare_torch_dataloader
from .mix_dataloader import MixDataLoader

View File

@ -168,6 +168,7 @@ from fastNLP.core.collators import Collator
from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress
from fastNLP.core.utils.tqdm_progress import f_tqdm_progress
from ..log import logger
from fastNLP.core.utils.dummy_class import DummyClass
progress_bars = {
@ -231,7 +232,8 @@ def _multi_proc(ds, _apply_field, func, counter, queue):
"""
idx = -1
import contextlib
with contextlib.redirect_stdout(None): # 避免打印触发 rich 的锁
null = DummyClass()
with contextlib.redirect_stdout(null): # 避免打印触发 rich 的锁
logger.set_stdout(stdout='raw')
results = []
try:
@ -597,7 +599,8 @@ class DataSet:
.. note::
由于 ``python`` 语言的特性设置该参数后会导致相应倍数的内存增长这可能会对您程序的执行带来一定的影响
由于 ``python`` 语言的特性设置该参数后会导致相应倍数的内存增长这可能会对您程序的执行带来一定的影响另外使用多进程时
``func`` 函数中的打印将不会输出
:param progress_desc: 进度条的描述字符默认为 ``Processing``
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
@ -631,8 +634,14 @@ class DataSet:
:param field_name: 传入func的是哪个field
:param func: 参数是 ``DataSet`` 中的 ``Instance`` 返回值是一个字典key 是field 的名字value 是对应的结果
:param modify_fields: 是否用结果修改 `DataSet` 中的 `Field` 默认为 True
:param num_proc: 进程的数量请注意由于python语言的特性多少进程就会导致多少倍内存的增长
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
:param num_proc: 使用进程的数量
.. note::
由于 ``python`` 语言的特性设置该参数后会导致相应倍数的内存增长这可能会对您程序的执行带来一定的影响另外使用多进程时
``func`` 函数中的打印将不会输出
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
:param progress_desc: 当显示 progress_bar 显示当前正在处理的进度条描述字符
:return Dict[str:Field]: 返回一个字典
"""
@ -672,7 +681,13 @@ class DataSet:
progress_bar: str = 'rich', _apply_field: str = None,
progress_desc: str = 'Main') -> list:
"""
:param num_proc: 进程的数量请注意由于python语言的特性多少进程就会导致多少倍内存的增长
:param num_proc: 使用进程的数量
.. note::
由于 ``python`` 语言的特性设置该参数后会导致相应倍数的内存增长这可能会对您程序的执行带来一定的影响另外使用多进程时
``func`` 函数中的打印将不会输出
:param func: 用户自定义处理函数参数是 ``DataSet`` 中的 ``Instance``
:param _apply_field: 需要传进去func的数据集的field_name
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
@ -744,7 +759,13 @@ class DataSet:
:param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` 默认为 True
:param func: 参数是 ``DataSet`` 中的 ``Instance`` 返回值是一个字典key 是field 的名字value 是对应的结果
:param num_proc: 进程的数量请注意由于python语言的特性多少进程就会导致多少倍内存的增长
:param num_proc: 使用进程的数量
.. note::
由于 ``python`` 语言的特性设置该参数后会导致相应倍数的内存增长这可能会对您程序的执行带来一定的影响另外使用多进程时
``func`` 函数中的打印将不会输出
:param progress_desc: progress_bar 不为 None 可以显示当前正在处理的进度条名称
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
:return Dict[str:Field]: 返回一个字典
@ -789,7 +810,13 @@ class DataSet:
:param func: 参数是 ``DataSet`` 中的 ``Instance`` 返回值是一个字典key 是field 的名字value 是对应的结果
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中如果名称与已有的field相同则覆
盖之前的field如果为None则不创建新的field
:param num_proc: 进程的数量请注意由于python语言的特性多少进程就会导致多少倍内存的增长
:param num_proc: 使用进程的数量
.. note::
由于 ``python`` 语言的特性设置该参数后会导致相应倍数的内存增长这可能会对您程序的执行带来一定的影响另外使用多进程时
``func`` 函数中的打印将不会输出
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
:param progress_desc: progress bar 显示的值默认为空
"""

View File

@ -175,6 +175,14 @@ class Driver(ABC):
raise NotImplementedError(
"Each specific driver should implemented its own `_check_optimizer_legality` function.")
def check_dataloader_legality(self, dataloader):
"""
检测 DataLoader 是否合法如果不合法 raise TypeError
:param dataloder:
:return:
"""
def set_optimizers(self, optimizers=None):
r"""
trainer 会调用该函数将用户传入的 optimizers 挂载到 driver 实例上

View File

@ -13,6 +13,7 @@ if _NEED_IMPORT_JITTOR:
import jittor as jt
from jittor import Module
from jittor.optim import Optimizer
from jittor.dataset import Dataset
_reduces = {
'max': jt.max,
@ -52,21 +53,11 @@ class JittorDriver(Driver):
# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
@staticmethod
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
def check_dataloader_legality(self, dataloader):
# 在fastnlp中实现了JittorDataLoader
# TODO: 是否允许传入Dataset
if is_train:
if not isinstance(dataloader, JittorDataLoader):
raise ValueError(f"Parameter `{dataloader_name}` should be 'JittorDataLoader' type, not {type(dataloader)}.")
else:
if not isinstance(dataloader, Dict):
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
else:
for each_dataloader in dataloader.values():
if not isinstance(each_dataloader, JittorDataLoader):
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'JittorDataLoader' "
f"type, not {type(each_dataloader)}.")
if not isinstance(dataloader, Dataset):
raise TypeError(f"{Dataset} is expected, instead of `{type(dataloader)}`")
@staticmethod
def _check_optimizer_legality(optimizers):

View File

@ -94,29 +94,9 @@ class PaddleDriver(Driver):
self.grad_scaler.step(optimizer)
self.grad_scaler.update()
@staticmethod
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
if is_train:
if not isinstance(dataloader, DataLoader):
raise ValueError(f"Parameter `{dataloader_name}` should be 'paddle.io.DataLoader' type, not {type(dataloader)}.")
# TODO 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类;
if isinstance(dataloader.dataset, IterableDataset):
raise TypeError("`IterableDataset` is not allowed.")
if dataloader.batch_sampler is None and dataloader.batch_size is None:
raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.")
else:
if not isinstance(dataloader, Dict):
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
else:
for each_dataloader in dataloader.values():
if not isinstance(each_dataloader, DataLoader):
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'paddle.io.DataLoader' "
f"type, not {type(each_dataloader)}.")
if isinstance(each_dataloader.dataset, IterableDataset):
raise TypeError("`IterableDataset` is not allowed.")
if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None:
raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of "
f"`batch_sampler` and `batch_size` should be set.")
def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, DataLoader):
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
@staticmethod
def _check_optimizer_legality(optimizers):

View File

@ -91,26 +91,9 @@ class TorchDriver(Driver):
self.grad_scaler.step(optimizer)
self.grad_scaler.update()
@staticmethod
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
if is_train:
if not isinstance(dataloader, DataLoader):
raise ValueError(f"Parameter `{dataloader_name}` should be 'DataLoader' type, not {type(dataloader)}.")
# todo 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类;
if isinstance(dataloader.dataset, IterableDataset):
raise TypeError("`IterableDataset` is not allowed.")
else:
if not isinstance(dataloader, Dict):
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
else:
for each_dataloader in dataloader.values():
if not isinstance(each_dataloader, DataLoader):
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'DataLoader' "
f"type, not {type(each_dataloader)}.")
if isinstance(each_dataloader.dataset, IterableDataset):
raise TypeError("`IterableDataset` is not allowed.")
def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, DataLoader):
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
@staticmethod
def _check_optimizer_legality(optimizers):

View File

@ -3,3 +3,9 @@ __all__ = []
class DummyClass:
def __init__(self, *args, **kwargs):
pass
def __getattr__(self, item):
return lambda *args, **kwargs: ...
def __call__(self, *args, **kwargs):
pass

View File

@ -4,7 +4,8 @@ __all__ = [
import uuid
import sys
from ...envs.imports import _module_available, _compare_version
from ...envs.utils import _module_available, _compare_version, _get_version
from ...envs import get_global_rank
from .utils import is_notebook
from ..log import logger
@ -82,8 +83,10 @@ class TqdmProgress(metaclass=Singleton):
:param kwargs:
:return:
"""
assert _module_available('tqdm') and _compare_version('tqdm', operator.ge, '4.57'), \
f"To use tqdm, tqdm>=4.57 is needed."
if not _module_available('tqdm'):
raise ModuleNotFoundError("Package tqdm is not installed.")
elif not _compare_version('tqdm', operator.ge, '4.57'):
raise RuntimeError(f"Package tqdm>=4.57 is needed, instead of {_get_version('tqdm')}.")
from .rich_progress import f_rich_progress
assert not f_rich_progress.not_empty(), "Cannot use tqdm before rich finish loop."

View File

@ -26,6 +26,25 @@ def _module_available(module_path: str) -> bool:
return False
def _get_version(package, use_base_version: bool = False):
try:
pkg = importlib.import_module(package)
except (ModuleNotFoundError, DistributionNotFound):
return False
try:
if hasattr(pkg, "__version__"):
pkg_version = Version(pkg.__version__)
else:
# try pkg_resources to infer version
pkg_version = Version(pkg_resources.get_distribution(package).version)
except TypeError:
# this is mocked by Sphinx, so it should return True to generate all summaries
return True
if use_base_version:
pkg_version = Version(pkg_version.base_version)
return pkg_version
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
"""Compare package version with some requirements.

View File

@ -231,7 +231,13 @@ class DataBundle:
盖之前的field如果为None则不创建新的field
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时如果为True则直接忽略该DataSet;
如果为False则报错
:param num_proc: 进程的数量请注意由于python语言的特性多少进程就会导致多少倍内存的增长
:param num_proc: 使用进程的数量
.. note::
由于 ``python`` 语言的特性设置该参数后会导致相应倍数的内存增长这可能会对您程序的执行带来一定的影响另外使用多进程时
``func`` 函数中的打印将不会输出
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} 就直接跳过这个 dataset
:param progress_desc: 当显示 progress 可以显示当前正在处理的名称
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
@ -260,10 +266,16 @@ class DataBundle:
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` 返回值是一个字典key 是field 的名字value 是对应的结果
:param str field_name: 传入func的是哪个field
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field` 默认为 True
:param num_proc: 进程的数量请注意由于python语言的特性多少进程就会导致多少倍内存的增长
:param num_proc: 使用进程的数量
.. note::
由于 ``python`` 语言的特性设置该参数后会导致相应倍数的内存增长这可能会对您程序的执行带来一定的影响另外使用多进程时
``func`` 函数中的打印将不会输出
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时如果为True则直接忽略该DataSet;
如果为False则报错
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
:param progress_desc: 当显示 progress_bar 可以显示 ``progress`` 的名称
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典第一层的 key dataset 的名字第二层的 key field 的名字
@ -292,8 +304,14 @@ class DataBundle:
:param callable func: input是instance中名为 `field_name` 的field的内容
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中如果名称与已有的field相同则覆
盖之前的field如果为None则不创建新的field
:param num_proc: 进程的数量请注意由于python语言的特性多少进程就会导致多少倍内存的增长
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
:param num_proc: 使用进程的数量
.. note::
由于 ``python`` 语言的特性设置该参数后会导致相应倍数的内存增长这可能会对您程序的执行带来一定的影响另外使用多进程时
``func`` 函数中的打印将不会输出
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
:param progress_desc: 当显示 progress bar 可以显示当前正在处理的名称
"""
@ -316,8 +334,14 @@ class DataBundle:
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` 返回值是一个字典key 是field 的名字value 是对应的结果
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` 默认为 True
:param num_proc: 进程的数量请注意由于python语言的特性多少进程就会导致多少倍内存的增长
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
:param num_proc: 使用进程的数量
.. note::
由于 ``python`` 语言的特性设置该参数后会导致相应倍数的内存增长这可能会对您程序的执行带来一定的影响另外使用多进程时
``func`` 函数中的打印将不会输出
:param progress_bar: 显示 progress_bar 的方式支持 `["rich", "tqdm", None]`
:param progress_desc: 当显示 progress_bar 可以显示当前正在处理的名称
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典第一层的 key dataset 的名字第二层的 key field 的名字
@ -382,4 +406,3 @@ class DataBundle:
for name, vocab in self.vocabs.items():
_str += '\t{} has {} entries.\n'.format(name, len(vocab))
return _str

View File

@ -4,6 +4,7 @@ import pytest
import numpy as np
from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException
from fastNLP import logger
class TestDataSetInit:
@ -379,6 +380,14 @@ class TestDataSetMethods:
data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=0)
def test_apply_more_proc(self):
def func(x):
print("x")
logger.info("demo")
return len(x)
data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
data.apply_field(func, field_name='x', new_field_name='len_x', num_proc=2)
class TestFieldArrayInit:
"""