mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-29 18:59:01 +08:00
1.修复数据处理的时候,多进程print会报错的问题;2.Trainer/Evaluator增加check_dataloader_legality
This commit is contained in:
parent
cfe389631a
commit
ff50437702
@ -111,6 +111,7 @@ class Evaluator:
|
||||
分布式进行设置。如果为 ``True``,将使得每个进程上的 ``dataloader`` 自动使用不同数据,所有进程的数据并集是整个数据集;
|
||||
* *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数;
|
||||
* *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数;
|
||||
* *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法,默认为 ``True`` 。
|
||||
|
||||
"""
|
||||
|
||||
@ -134,6 +135,8 @@ class Evaluator:
|
||||
self.device = device
|
||||
self.verbose = verbose
|
||||
|
||||
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn)
|
||||
|
||||
if evaluate_batch_step_fn is not None:
|
||||
_check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn')
|
||||
self.evaluate_batch_step_fn = evaluate_batch_step_fn
|
||||
@ -141,10 +144,23 @@ class Evaluator:
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
# check dataloader
|
||||
if not isinstance(dataloaders, dict):
|
||||
if kwargs.get('check_dataloader_legality', True):
|
||||
try:
|
||||
self.driver.check_dataloader_legality(dataloader=dataloaders)
|
||||
except TypeError as e:
|
||||
logger.error("`dataloaders` is invalid.")
|
||||
raise e
|
||||
dataloaders = {None: dataloaders}
|
||||
|
||||
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn)
|
||||
else:
|
||||
if kwargs.get('check_dataloader_legality', True):
|
||||
for key, dataloader in dataloaders.items():
|
||||
try:
|
||||
self.driver.check_dataloader_legality(dataloader=dataloader)
|
||||
except TypeError as e:
|
||||
logger.error(f"The dataloader named:{key} is invalid.")
|
||||
raise e
|
||||
|
||||
self.driver.setup()
|
||||
self.driver.barrier()
|
||||
@ -333,7 +349,7 @@ class Evaluator:
|
||||
|
||||
@evaluate_batch_loop.setter
|
||||
def evaluate_batch_loop(self, loop: Loop):
|
||||
if self.evaluate_batch_step_fn is not None:
|
||||
if getattr(self, 'evaluate_step_fn', None) is not None:
|
||||
logger.rank_zero_warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored "
|
||||
"when the `evaluate_batch_loop` is also customized.")
|
||||
self._evaluate_batch_loop = loop
|
||||
|
@ -304,6 +304,7 @@ class Trainer(TrainerEventTrigger):
|
||||
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。
|
||||
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。
|
||||
* *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Evaluator`` 中。与 output_mapping 互斥。
|
||||
* *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法,默认为 ``True`` 。
|
||||
|
||||
.. note::
|
||||
``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证;
|
||||
@ -463,6 +464,14 @@ class Trainer(TrainerEventTrigger):
|
||||
self.driver.setup()
|
||||
self.driver.barrier()
|
||||
|
||||
# check train_dataloader
|
||||
if kwargs.get('check_dataloader_legality', True):
|
||||
try:
|
||||
self.driver.check_dataloader_legality(dataloader=train_dataloader)
|
||||
except TypeError as e:
|
||||
logger.error("`train_dataloader` is invalid.")
|
||||
raise e
|
||||
|
||||
use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed())
|
||||
if use_dist_sampler:
|
||||
_dist_sampler = "dist"
|
||||
@ -482,7 +491,8 @@ class Trainer(TrainerEventTrigger):
|
||||
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
|
||||
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
|
||||
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
|
||||
progress_bar=progress_bar)
|
||||
progress_bar=progress_bar,
|
||||
check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
|
||||
|
||||
if train_fn is not None and not isinstance(train_fn, str):
|
||||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
|
||||
|
@ -10,8 +10,7 @@ __all__ = [
|
||||
"prepare_dataloader"
|
||||
]
|
||||
|
||||
from .mix_dataloader import MixDataLoader
|
||||
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
|
||||
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
|
||||
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader
|
||||
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
|
||||
from .prepare_dataloader import prepare_dataloader
|
@ -1,6 +1,8 @@
|
||||
__all__ = [
|
||||
"TorchDataLoader",
|
||||
"prepare_torch_dataloader"
|
||||
"prepare_torch_dataloader",
|
||||
"MixDataLoader"
|
||||
]
|
||||
|
||||
from .fdl import TorchDataLoader, prepare_torch_dataloader
|
||||
from .mix_dataloader import MixDataLoader
|
||||
|
@ -168,6 +168,7 @@ from fastNLP.core.collators import Collator
|
||||
from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress
|
||||
from fastNLP.core.utils.tqdm_progress import f_tqdm_progress
|
||||
from ..log import logger
|
||||
from fastNLP.core.utils.dummy_class import DummyClass
|
||||
|
||||
|
||||
progress_bars = {
|
||||
@ -231,7 +232,8 @@ def _multi_proc(ds, _apply_field, func, counter, queue):
|
||||
"""
|
||||
idx = -1
|
||||
import contextlib
|
||||
with contextlib.redirect_stdout(None): # 避免打印触发 rich 的锁
|
||||
null = DummyClass()
|
||||
with contextlib.redirect_stdout(null): # 避免打印触发 rich 的锁
|
||||
logger.set_stdout(stdout='raw')
|
||||
results = []
|
||||
try:
|
||||
@ -597,7 +599,8 @@ class DataSet:
|
||||
|
||||
.. note::
|
||||
|
||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。
|
||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
|
||||
``func`` 函数中的打印将不会输出。
|
||||
|
||||
:param progress_desc: 进度条的描述字符,默认为 ``Processing``;
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
@ -631,8 +634,14 @@ class DataSet:
|
||||
:param field_name: 传入func的是哪个field。
|
||||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
|
||||
:param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True
|
||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
:param num_proc: 使用进程的数量。
|
||||
|
||||
.. note::
|
||||
|
||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
|
||||
``func`` 函数中的打印将不会输出。
|
||||
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
:param progress_desc: 当显示 progress_bar 时,显示当前正在处理的进度条描述字符
|
||||
:return Dict[str:Field]: 返回一个字典
|
||||
"""
|
||||
@ -672,7 +681,13 @@ class DataSet:
|
||||
progress_bar: str = 'rich', _apply_field: str = None,
|
||||
progress_desc: str = 'Main') -> list:
|
||||
"""
|
||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
|
||||
:param num_proc: 使用进程的数量。
|
||||
|
||||
.. note::
|
||||
|
||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
|
||||
``func`` 函数中的打印将不会输出。
|
||||
|
||||
:param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance``
|
||||
:param _apply_field: 需要传进去func的数据集的field_name
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
@ -744,7 +759,13 @@ class DataSet:
|
||||
|
||||
:param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
|
||||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
|
||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
|
||||
:param num_proc: 使用进程的数量。
|
||||
|
||||
.. note::
|
||||
|
||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
|
||||
``func`` 函数中的打印将不会输出。
|
||||
|
||||
:param progress_desc: 当 progress_bar 不为 None 时,可以显示当前正在处理的进度条名称
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
:return Dict[str:Field]: 返回一个字典
|
||||
@ -789,7 +810,13 @@ class DataSet:
|
||||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
|
||||
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
|
||||
盖之前的field。如果为None则不创建新的field。
|
||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
|
||||
:param num_proc: 使用进程的数量。
|
||||
|
||||
.. note::
|
||||
|
||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
|
||||
``func`` 函数中的打印将不会输出。
|
||||
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
:param progress_desc: progress bar 显示的值,默认为空。
|
||||
"""
|
||||
|
@ -175,6 +175,14 @@ class Driver(ABC):
|
||||
raise NotImplementedError(
|
||||
"Each specific driver should implemented its own `_check_optimizer_legality` function.")
|
||||
|
||||
def check_dataloader_legality(self, dataloader):
|
||||
"""
|
||||
检测 DataLoader 是否合法,如果不合法,会 raise TypeError 。
|
||||
|
||||
:param dataloder:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def set_optimizers(self, optimizers=None):
|
||||
r"""
|
||||
trainer 会调用该函数将用户传入的 optimizers 挂载到 driver 实例上;
|
||||
|
@ -13,6 +13,7 @@ if _NEED_IMPORT_JITTOR:
|
||||
import jittor as jt
|
||||
from jittor import Module
|
||||
from jittor.optim import Optimizer
|
||||
from jittor.dataset import Dataset
|
||||
|
||||
_reduces = {
|
||||
'max': jt.max,
|
||||
@ -52,21 +53,11 @@ class JittorDriver(Driver):
|
||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
|
||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
|
||||
|
||||
@staticmethod
|
||||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
|
||||
def check_dataloader_legality(self, dataloader):
|
||||
# 在fastnlp中实现了JittorDataLoader
|
||||
# TODO: 是否允许传入Dataset?
|
||||
if is_train:
|
||||
if not isinstance(dataloader, JittorDataLoader):
|
||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'JittorDataLoader' type, not {type(dataloader)}.")
|
||||
else:
|
||||
if not isinstance(dataloader, Dict):
|
||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
|
||||
else:
|
||||
for each_dataloader in dataloader.values():
|
||||
if not isinstance(each_dataloader, JittorDataLoader):
|
||||
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'JittorDataLoader' "
|
||||
f"type, not {type(each_dataloader)}.")
|
||||
if not isinstance(dataloader, Dataset):
|
||||
raise TypeError(f"{Dataset} is expected, instead of `{type(dataloader)}`")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _check_optimizer_legality(optimizers):
|
||||
|
@ -94,29 +94,9 @@ class PaddleDriver(Driver):
|
||||
self.grad_scaler.step(optimizer)
|
||||
self.grad_scaler.update()
|
||||
|
||||
@staticmethod
|
||||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
|
||||
if is_train:
|
||||
if not isinstance(dataloader, DataLoader):
|
||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'paddle.io.DataLoader' type, not {type(dataloader)}.")
|
||||
# TODO 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类;
|
||||
if isinstance(dataloader.dataset, IterableDataset):
|
||||
raise TypeError("`IterableDataset` is not allowed.")
|
||||
if dataloader.batch_sampler is None and dataloader.batch_size is None:
|
||||
raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.")
|
||||
else:
|
||||
if not isinstance(dataloader, Dict):
|
||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
|
||||
else:
|
||||
for each_dataloader in dataloader.values():
|
||||
if not isinstance(each_dataloader, DataLoader):
|
||||
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'paddle.io.DataLoader' "
|
||||
f"type, not {type(each_dataloader)}.")
|
||||
if isinstance(each_dataloader.dataset, IterableDataset):
|
||||
raise TypeError("`IterableDataset` is not allowed.")
|
||||
if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None:
|
||||
raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of "
|
||||
f"`batch_sampler` and `batch_size` should be set.")
|
||||
def check_dataloader_legality(self, dataloader):
|
||||
if not isinstance(dataloader, DataLoader):
|
||||
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
|
||||
|
||||
@staticmethod
|
||||
def _check_optimizer_legality(optimizers):
|
||||
|
@ -91,26 +91,9 @@ class TorchDriver(Driver):
|
||||
self.grad_scaler.step(optimizer)
|
||||
self.grad_scaler.update()
|
||||
|
||||
@staticmethod
|
||||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
|
||||
if is_train:
|
||||
if not isinstance(dataloader, DataLoader):
|
||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'DataLoader' type, not {type(dataloader)}.")
|
||||
|
||||
# todo 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类;
|
||||
if isinstance(dataloader.dataset, IterableDataset):
|
||||
raise TypeError("`IterableDataset` is not allowed.")
|
||||
|
||||
else:
|
||||
if not isinstance(dataloader, Dict):
|
||||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
|
||||
else:
|
||||
for each_dataloader in dataloader.values():
|
||||
if not isinstance(each_dataloader, DataLoader):
|
||||
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'DataLoader' "
|
||||
f"type, not {type(each_dataloader)}.")
|
||||
if isinstance(each_dataloader.dataset, IterableDataset):
|
||||
raise TypeError("`IterableDataset` is not allowed.")
|
||||
def check_dataloader_legality(self, dataloader):
|
||||
if not isinstance(dataloader, DataLoader):
|
||||
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
|
||||
|
||||
@staticmethod
|
||||
def _check_optimizer_legality(optimizers):
|
||||
|
@ -3,3 +3,9 @@ __all__ = []
|
||||
class DummyClass:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __getattr__(self, item):
|
||||
return lambda *args, **kwargs: ...
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
pass
|
@ -4,7 +4,8 @@ __all__ = [
|
||||
|
||||
import uuid
|
||||
import sys
|
||||
from ...envs.imports import _module_available, _compare_version
|
||||
from ...envs.utils import _module_available, _compare_version, _get_version
|
||||
|
||||
from ...envs import get_global_rank
|
||||
from .utils import is_notebook
|
||||
from ..log import logger
|
||||
@ -82,8 +83,10 @@ class TqdmProgress(metaclass=Singleton):
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
assert _module_available('tqdm') and _compare_version('tqdm', operator.ge, '4.57'), \
|
||||
f"To use tqdm, tqdm>=4.57 is needed."
|
||||
if not _module_available('tqdm'):
|
||||
raise ModuleNotFoundError("Package tqdm is not installed.")
|
||||
elif not _compare_version('tqdm', operator.ge, '4.57'):
|
||||
raise RuntimeError(f"Package tqdm>=4.57 is needed, instead of {_get_version('tqdm')}.")
|
||||
|
||||
from .rich_progress import f_rich_progress
|
||||
assert not f_rich_progress.not_empty(), "Cannot use tqdm before rich finish loop."
|
||||
|
@ -26,6 +26,25 @@ def _module_available(module_path: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _get_version(package, use_base_version: bool = False):
|
||||
try:
|
||||
pkg = importlib.import_module(package)
|
||||
except (ModuleNotFoundError, DistributionNotFound):
|
||||
return False
|
||||
try:
|
||||
if hasattr(pkg, "__version__"):
|
||||
pkg_version = Version(pkg.__version__)
|
||||
else:
|
||||
# try pkg_resources to infer version
|
||||
pkg_version = Version(pkg_resources.get_distribution(package).version)
|
||||
except TypeError:
|
||||
# this is mocked by Sphinx, so it should return True to generate all summaries
|
||||
return True
|
||||
if use_base_version:
|
||||
pkg_version = Version(pkg_version.base_version)
|
||||
return pkg_version
|
||||
|
||||
|
||||
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
|
||||
"""Compare package version with some requirements.
|
||||
|
||||
|
@ -231,7 +231,13 @@ class DataBundle:
|
||||
盖之前的field。如果为None则不创建新的field。
|
||||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
|
||||
如果为False,则报错
|
||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
|
||||
:param num_proc: 使用进程的数量。
|
||||
|
||||
.. note::
|
||||
|
||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
|
||||
``func`` 函数中的打印将不会输出。
|
||||
|
||||
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。
|
||||
:param progress_desc: 当显示 progress 时,可以显示当前正在处理的名称
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
@ -260,10 +266,16 @@ class DataBundle:
|
||||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
|
||||
:param str field_name: 传入func的是哪个field。
|
||||
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True
|
||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
|
||||
:param num_proc: 使用进程的数量。
|
||||
|
||||
.. note::
|
||||
|
||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
|
||||
``func`` 函数中的打印将不会输出。
|
||||
|
||||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
|
||||
如果为False,则报错
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
:param progress_desc: 当显示 progress_bar 时,可以显示 ``progress`` 的名称。
|
||||
|
||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
|
||||
@ -292,8 +304,14 @@ class DataBundle:
|
||||
:param callable func: input是instance中名为 `field_name` 的field的内容。
|
||||
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
|
||||
盖之前的field。如果为None则不创建新的field。
|
||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
:param num_proc: 使用进程的数量。
|
||||
|
||||
.. note::
|
||||
|
||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
|
||||
``func`` 函数中的打印将不会输出。
|
||||
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
:param progress_desc: 当显示 progress bar 时,可以显示当前正在处理的名称
|
||||
|
||||
"""
|
||||
@ -316,8 +334,14 @@ class DataBundle:
|
||||
|
||||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
|
||||
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
|
||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
:param num_proc: 使用进程的数量。
|
||||
|
||||
.. note::
|
||||
|
||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时,
|
||||
``func`` 函数中的打印将不会输出。
|
||||
|
||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
|
||||
:param progress_desc: 当显示 progress_bar 时,可以显示当前正在处理的名称
|
||||
|
||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
|
||||
@ -382,4 +406,3 @@ class DataBundle:
|
||||
for name, vocab in self.vocabs.items():
|
||||
_str += '\t{} has {} entries.\n'.format(name, len(vocab))
|
||||
return _str
|
||||
|
||||
|
@ -4,6 +4,7 @@ import pytest
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException
|
||||
from fastNLP import logger
|
||||
|
||||
|
||||
class TestDataSetInit:
|
||||
@ -379,6 +380,14 @@ class TestDataSetMethods:
|
||||
data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
|
||||
data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=0)
|
||||
|
||||
def test_apply_more_proc(self):
|
||||
def func(x):
|
||||
print("x")
|
||||
logger.info("demo")
|
||||
return len(x)
|
||||
data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
|
||||
data.apply_field(func, field_name='x', new_field_name='len_x', num_proc=2)
|
||||
|
||||
|
||||
class TestFieldArrayInit:
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user