Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

This commit is contained in:
x54-729 2022-04-16 05:40:35 +00:00
commit a25a73394b
29 changed files with 264 additions and 351 deletions

View File

@ -138,10 +138,6 @@ class CheckpointCallback(HasMonitorCallback):
f'exception_{exception.__class__.__name__}'
self.save(trainer=trainer, folder_name=folder_name)
def on_sanity_check_end(self, trainer, sanity_check_res):
# Mainly check that the monitor actually exists.
self.get_monitor_value(results=sanity_check_res)
def on_save_checkpoint(self, trainer) -> Dict:
"""
Save timestamp_path so that training can be resumed later and keep saving into this folder

View File

@ -49,7 +49,8 @@ class HasMonitorCallback(Callback):
self.monitor = monitor
else:
self.monitor = str(monitor) if monitor is not None else None
self.larger_better = bool(larger_better)
if self.monitor is not None:
self.larger_better = bool(larger_better)
if larger_better:
self.monitor_value = float('-inf')
else:
@ -71,6 +72,12 @@ class HasMonitorCallback(Callback):
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
f"You can set it in the initialization or through Trainer.")
def on_sanity_check_end(self, trainer, sanity_check_res):
# Mainly check that the monitor actually exists.
if self.monitor is not None:
self.get_monitor_value(results=sanity_check_res)
def get_monitor_value(self, results:Dict)->Union[float, None]:
"""
Gets the value of the monitor. If the monitor key is not found directly, fuzzy matching is attempted, and the matched key is assigned to the self._real_monitor attribute.
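For intuition, a minimal sketch of the fuzzy matching described above; the helper and its cutoff are illustrative assumptions, not the fastNLP implementation:

import difflib
from typing import Dict, Optional

def match_monitor(monitor: str, results: Dict[str, float]) -> Optional[str]:
    # an exact key wins; otherwise fall back to the closest result key,
    # which get_monitor_value would then record as self._real_monitor
    if monitor in results:
        return monitor
    close = difflib.get_close_matches(monitor, results.keys(), n=1, cutoff=0.2)
    return close[0] if close else None

assert match_monitor('acc', {'acc#acc#dev': 0.9}) == 'acc#acc#dev'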

View File

@ -10,7 +10,7 @@ import shutil
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH
from fastNLP.core.log import logger
from fastNLP.envs import all_rank_call
from fastNLP.envs import all_rank_call_context
class LoadBestModelCallback(HasMonitorCallback):
@ -76,9 +76,6 @@ class LoadBestModelCallback(HasMonitorCallback):
super().on_after_trainer_initialized(trainer, driver)
def on_sanity_check_end(self, trainer, sanity_check_res):
self.get_monitor_value(sanity_check_res)
def on_validate_end(self, trainer, results):
if self.is_better_results(results, keep_if_better=True):
if self.real_save_folder:
@ -86,7 +83,7 @@ class LoadBestModelCallback(HasMonitorCallback):
model_save_fn=self.model_save_fn)
else:
self.buffer.seek(0)
with all_rank_call():
with all_rank_call_context():
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict)
def on_train_end(self, trainer):

View File

@ -11,14 +11,15 @@ class LRSchedCallback(Callback):
Calls the scheduler's step function at the appropriate time according to the step_on parameter.
:param scheduler: an object that implements a step() function
:param step_on: one of ['batch', 'epoch'], indicating when to call the scheduler's step function
:param step_on: one of ['batch', 'epoch'], indicating when to call the scheduler's step function; if 'batch', it is called before each parameter
update; if 'epoch', it is called after an epoch finishes
"""
assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \
"step function."
self.scheduler = scheduler
self.step_on = 0 if step_on == 'batch' else 1
def on_train_batch_end(self, trainer):
def on_before_optimizers_step(self, trainer, optimizers):
if self.step_on == 0:
self.scheduler.step()
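A brief usage sketch of this callback with a torch scheduler; with step_on='batch' the scheduler now steps in on_before_optimizers_step, i.e. before each parameter update (the Trainer wiring is illustrative):

import torch
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=100, gamma=0.5)
callback = LRSchedCallback(scheduler, step_on='batch')  # steps before every optimizer update
# trainer = Trainer(model, ..., callbacks=[callback])   # remaining Trainer arguments omitted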

View File

@ -32,10 +32,6 @@ class ProgressCallback(HasMonitorCallback):
def on_train_end(self, trainer):
f_rich_progress.stop()
def on_sanity_check_end(self, trainer, sanity_check_res):
if len(sanity_check_res) and getattr(self, 'monitor', None) is not None:
self.get_monitor_value(sanity_check_res)
class RichCallback(ProgressCallback):
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,

View File

@ -3,7 +3,6 @@ from functools import partial
from dataclasses import is_dataclass
import sys
__all__ = [
'Evaluator'
]
@ -75,8 +74,8 @@ class Evaluator:
When auto_tensor_conversion_for_metric is True, fastNLP automatically converts paddle tensors in the output (other non-tensor arguments are left untouched)
into pytorch tensors before feeding them to the metrics. The tensor type of the model output is determined by the driver;
the input types the metrics accept are determined by the metrics themselves. For more complex conversions, use the input_mapping and output_mapping parameters.
use_dist_sampler: whether to evaluate in a distributed fashion; only takes effect when the driver is a distributed type. If True, the dataloader on each process
automatically uses different data, and the union of the data across all processes is the whole dataset; make sure the metrics in use support automatic distributed accumulation.
use_dist_sampler: whether to evaluate in a distributed fashion; only takes effect when the driver is a distributed type. Defaults to whether the driver
supports distributed execution. If True, the dataloader on each process automatically uses different data, and the union of the data across all processes is the whole dataset.
output_from_new_proc: should be a string indicating how the output streams of the other processes are handled in a multi-process driver. It should be one of
["all", "ignore", "only_error"]; any other value is treated as a folder name, and the output streams of the other ranks are redirected into
log files, which are then saved under the folder named by this value. Defaults to "only_error".
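A hedged construction sketch matching the new default described above (the dataloader/metric arguments shown are illustrative):

# use_dist_sampler now defaults to driver.is_distributed(), so under a
# distributed driver each process automatically evaluates a disjoint shard
evaluator = Evaluator(model, dataloaders=val_dataloader,
                      metrics={'acc': Accuracy()},
                      driver='torch', device=[0, 1])
# pass use_dist_sampler=False to make every process see the full dataset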
@ -86,7 +85,8 @@ class Evaluator:
self.model = model
self.metrics = metrics
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs)
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call,
**kwargs)
if dataloaders is None:
raise ValueError("Parameter `dataloaders` can not be None.")
@ -105,9 +105,13 @@ class Evaluator:
dataloaders = {None: dataloaders}
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn)
self.driver.setup()
self.driver.barrier()
self.separator = kwargs.get('separator', '#')
self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True)
use_dist_sampler = kwargs.get("use_dist_sampler", False) # 如果是 Evaluator 自身的默认值的话,应当为 False
use_dist_sampler = kwargs.get("use_dist_sampler", driver.is_distributed())
if use_dist_sampler:
self._dist_sampler = "unrepeatdist"
else:
@ -115,8 +119,9 @@ class Evaluator:
self._metric_wrapper = None
_ = self.metrics_wrapper  # trigger the check
self.driver.setup()
self.driver.barrier()
if self._dist_sampler is not None and not self.driver.is_distributed():
logger.warning_once("Running in a non-distributed driver, but with distributed sampler, it may cause "
"different process evaluating on different data.")
if evaluate_fn is not None and not isinstance(evaluate_fn, str):
raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.")
@ -183,7 +188,7 @@ class Evaluator:
return metric_results
def start_progress_bar(self, total:int, dataloader_name):
def start_progress_bar(self, total: int, dataloader_name):
if self.progress_bar == 'rich':
if dataloader_name is None:
desc = f'Eval. Batch:0'
@ -208,7 +213,7 @@ class Evaluator:
advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True),
visible=kwargs.get('visible', True))
elif self.progress_bar == 'raw':
if self.verbose>1:
if self.verbose > 1:
logger.info(desc)
def remove_progress_bar(self, dataloader_name):
@ -256,7 +261,7 @@ class Evaluator:
"""
self.metrics_wrapper.update(*args, **kwargs)
def get_dataloader_metric(self, dataloader_name:Optional[str]='') -> Dict:
def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict:
"""
Gets the metric results of the current dataloader.
@ -313,6 +318,7 @@ class _MetricsWrapper:
and, by wrapping the update(), reset(), and get_metric() functions, supports fastNLP metrics as well as torchmetrics and more.
"""
def __init__(self, metrics, evaluator):
self.evaluator = evaluator
self._metrics = []
@ -326,13 +332,14 @@ class _MetricsWrapper:
# torchmetrics enables multi-card support automatically by default
evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device)
elif isinstance(metric, Metric):
if evaluator._dist_sampler is not None and evaluator.driver.is_distributed() \
and metric.aggregate_when_get_metric is False:
logger.warning("You have replace the sampler as distributed sampler when evaluation, but your "
f"metric:{metric_name}' `aggregate_when_get_metric` is False.")
if evaluator._dist_sampler is None and evaluator.driver.is_distributed() \
and metric.aggregate_when_get_metric is True:
pass # this case is fine, because
# if the data is distributed but not aggregated there may be a problem
if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False:
logger.warning_once(
"You have replaced the sampler with a distributed sampler for evaluation, but the "
f"`aggregate_when_get_metric` of metric {metric_name}:{metric.__class__.__name__} is False.")
if metric.aggregate_when_get_metric is None:
metric.aggregate_when_get_metric = evaluator._dist_sampler is not None
metric.to(evaluator.driver.data_device)
self._metric_names.append(metric_name)
self._metrics.append(metric)
@ -343,8 +350,9 @@ class _MetricsWrapper:
for metric in self._metrics:
args = []
if not isinstance(batch, dict):
logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
f"the output of model to update metric.")
logger.warning_once(
f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
f"the output of model to update metric.")
else:
args.append(batch)
if not isinstance(outputs, dict):
@ -368,7 +376,7 @@ class _MetricsWrapper:
elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric):
metric.reset()
def get_metric(self, dataloader_name:str, separator:str) -> Dict:
def get_metric(self, dataloader_name: str, separator: str) -> Dict:
"""
Flattens all metric results into a single one-level dict whose keys follow the naming rule
indicator_name{separator}metric_name{separator}dataloader_name
@ -419,4 +427,4 @@ def _get_metric_res_name(dataloader_name: Optional[str], metric_name: str, indic
names.append(dataloader_name)
if len(names) == 0:
raise RuntimeError("You cannot use empty `dataloader_name`, `metric_name`, and `monitor` simultaneously.")
return separator.join(names)
return separator.join(names)

View File

@ -122,7 +122,8 @@ class Trainer(TrainerEventTrigger):
Note that if model_device is None, data_device has no effect.
torch_ddp_kwargs: parameters used to configure the initialization of pytorch's DistributedDataParallel.
set_grad_to_none: whether to set grads to None after every optimizer update during training.
use_dist_sampler: whether to replace the dataloader's sampler with a distributed sampler when using TorchDDPDriver; defaults to True.
use_dist_sampler: whether to use a distributed sampler. With multiple cards, the distributed sampler automatically decides which samples each card reads, so that within one epoch
the samples across all cards add up to exactly one pass over the dataset. By default this is set according to whether the driver is distributed.
use_eval_dist_sampler: whether, inside the Evaluator, the dataloader's sampler is replaced with a distributed sampler when using TorchDDPDriver; defaults to True.
output_from_new_proc: should be a string indicating how the output streams of the other processes are handled in a multi-process driver. It should be one of
["all", "ignore", "only_error"]; any other value is treated as a folder name, and the output streams of the other ranks are redirected into
@ -211,12 +212,6 @@ class Trainer(TrainerEventTrigger):
total_batches=None
)
use_dist_sampler = kwargs.get("use_dist_sampler", True)
if use_dist_sampler:
_dist_sampler = "dist"
else:
_dist_sampler = None
""" 设置内部的 Evaluator """
if metrics is None and evaluate_dataloaders is not None:
raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.")
@ -224,6 +219,18 @@ class Trainer(TrainerEventTrigger):
if metrics is not None and evaluate_dataloaders is None:
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.")
self.metrics = metrics
self.validate_every = evaluate_every
self.driver.setup()
self.driver.barrier()
use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed())
if use_dist_sampler:
_dist_sampler = "dist"
else:
_dist_sampler = None
self.evaluator = None
self.monitor = monitor
self.larger_better = larger_better
@ -241,16 +248,10 @@ class Trainer(TrainerEventTrigger):
output_mapping=output_mapping,
fp16=fp16,
verbose=0,
use_dist_sampler=kwargs.get("use_eval_dist_sampler", use_dist_sampler),
use_dist_sampler=kwargs.get("use_eval_dist_sampler", None),
progress_bar=kwargs.get('progress_bar', 'auto')
)
self.metrics = metrics
self.validate_every = evaluate_every
self.driver.setup()
self.driver.barrier()
if train_fn is not None and not isinstance(train_fn, str):
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
@ -753,7 +754,7 @@ class Trainer(TrainerEventTrigger):
"""
if (self.global_forward_batches + 1) % self.accumulation_steps != 0:
_no_sync_context = self.driver.get_no_sync_context()
_no_sync_context = self.driver.get_model_no_sync_context()
else:
_no_sync_context = nullcontext

View File

@ -199,9 +199,10 @@ class Driver(ABC):
"""
raise NotImplementedError("Each specific driver should implemented its own `zero_grad` function.")
def get_no_sync_context(self):
def get_model_no_sync_context(self):
r"""
Returns a context object that disables the mutual synchronization between processes; only multi-card drivers need to implement this function themselves, single-card drivers do not.
Returns a context object that disables the automatic cross-process synchronization of the model; only multi-card drivers need to implement this function themselves,
single-card drivers do not.
:return: a context object similar to DistributedDataParallel(model).no_sync
"""
@ -357,6 +358,8 @@ class Driver(ABC):
r"""
Synchronizes the progress of the processes when running with multiple processes: fast processes wait here for slow ones, and all processes only continue once every process has reached this function.
Only used in distributed training scenarios.
Note that the behavior of this function is affected by FASTNLP_NO_SYNC: the barrier is only actually executed when FASTNLP_NO_SYNC is absent from os.environ or is less than 1.
"""
def is_distributed(self) -> bool:
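A minimal sketch of how a trainer might consume get_model_no_sync_context during gradient accumulation; contextlib.nullcontext stands in on a single card (the helper itself is illustrative):

from contextlib import nullcontext

def accumulation_context(driver, step, accumulation_steps):
    # skip cross-rank gradient synchronization on the non-update steps
    if (step + 1) % accumulation_steps != 0 and driver.is_distributed():
        return driver.get_model_no_sync_context()
    return nullcontext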

View File

@ -82,7 +82,7 @@ class JittorMPIDriver(JittorDriver):
def is_global_zero(self):
return self.global_rank == 0
def get_no_sync_context(self):
def get_model_no_sync_context(self):
return self.model.no_sync
def unwrap_model(self):

View File

@ -403,7 +403,7 @@ class PaddleFleetDriver(PaddleDriver):
def is_global_zero(self):
return self.global_rank == 0
def get_no_sync_context(self):
def get_model_no_sync_context(self):
return self.model.no_sync
def unwrap_model(self):

View File

@ -5,7 +5,6 @@ import socket
import numpy as np
from time import sleep
from typing import List, Optional, Union, Dict, Tuple, Callable
from functools import partial
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
@ -29,7 +28,7 @@ from fastNLP.core.drivers.utils import distributed_open_proc
from fastNLP.core.utils import auto_param_call, check_user_specific_params
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
from fastNLP.core.log import logger
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
@ -511,7 +510,7 @@ class TorchDDPDriver(TorchDriver):
def is_global_zero(self):
return self.global_rank == 0
def get_no_sync_context(self):
def get_model_no_sync_context(self):
# note that model here is a "DistributedDataParallel" object;
return self.model.no_sync
@ -526,7 +525,8 @@ class TorchDDPDriver(TorchDriver):
return self.local_rank
def barrier(self):
torch.distributed.barrier(async_op=True)
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # only actually executed when FASTNLP_NO_SYNC is less than 1
torch.distributed.barrier(async_op=True)
def is_distributed(self):
return True
@ -544,6 +544,8 @@ class TorchDDPDriver(TorchDriver):
:return: if the current driver is not distributed, returns the input obj directly. If the current rank is a receiver (its global rank is included in dst), returns
the received object; if it is the source, returns the broadcast content; if it is neither sender nor receiver, returns None.
"""
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # if FASTNLP_NO_SYNC == 2, return directly.
return
return fastnlp_torch_broadcast_object(obj, src, device=self.data_device, group=group)
def all_gather(self, obj, group) -> List:
@ -569,6 +571,8 @@ class TorchDDPDriver(TorchDriver):
:param group:
:return:
"""
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # FASTNLP_NO_SYNC == 2 means the gather is skipped
return [obj]
return fastnlp_torch_all_gather(obj, group=group)
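The FASTNLP_NO_SYNC gating above, shown in isolation as a sketch: a level of 1 or more skips the barrier, and level 2 also short-circuits the collectives:

import os

def barrier_if_allowed():
    # the barrier only runs when FASTNLP_NO_SYNC is unset or below 1
    if int(os.environ.get('FASTNLP_NO_SYNC', 0)) < 1:
        torch.distributed.barrier()

def all_gather_if_allowed(obj, group=None):
    # level 2 behaves like a single-card program: the gather returns [obj]
    if int(os.environ.get('FASTNLP_NO_SYNC', 0)) == 2:
        return [obj]
    return fastnlp_torch_all_gather(obj, group=group)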

View File

@ -1,5 +1,6 @@
import io
import pickle
import os
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
from typing import Any, List
@ -7,6 +8,7 @@ from typing import Any, List
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.envs.env import FASTNLP_NO_SYNC
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed as dist
@ -34,47 +36,15 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
)
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
def fastnlp_torch_gather_object(obj, dst=0, group=DEFAULT_TORCH_GROUP):
"""
Gathers objects from the other ranks onto the dst rank.
Gathers picklable objects from the whole group in a single process.
Similar to :func:`gather`, but Python objects can be passed in. Note that the
object must be picklable in order to be gathered.
Args:
obj (Any): Input object. Must be picklable.
object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
should be correctly sized as the size of the group for this
collective and will contain the output. Must be ``None`` on non-dst
ranks. (default is ``None``)
dst (int, optional): Destination rank. (default is 0)
group: (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.
Returns:
None. On the ``dst`` rank, ``object_gather_list`` will contain the
output of the collective.
.. note:: Note that this API differs slightly from the gather collective
since it does not provide an async_op handle and thus will be a blocking
call.
.. note:: Note that this API is not supported when using the NCCL backend.
.. warning::
:func:`gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.
Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.gather_object(
>>> fastnlp_torch_gather_object(
gather_objects[dist.get_rank()],
output if dist.get_rank() == 0 else None,
dst=0
@ -82,7 +52,20 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAU
>>> # On rank 0
>>> output
['foo', 12, {1: 2}]
:param obj: the obj to send; must be picklable
:param dst: the destination rank
:param group: the group in which to run this function
:return: on dst, returns a list of length world_size containing the obj from rank 0, rank 1, ... in order
"""
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]
if dist.get_rank() == dst:
object_gather_list = [None for _ in range(dist.get_world_size(group))]
else:
object_gather_list = None
if group is None:
group = DEFAULT_TORCH_GROUP
@ -212,6 +195,9 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) -
:param group:
:return: the result is [obj0, obj1, ...], where obj_i is the obj on rank i
"""
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]
if group is None:
group = DEFAULT_TORCH_GROUP
if isinstance(obj, torch.Tensor):
@ -232,12 +218,18 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR
"""
src 上的 obj 对象广播到其它 rank
:param obj:
:param src:
:param obj: the object to send
:param src: the rank to send from
:param device:
:param group:
:param group: the communication group this call belongs to
:return:
"""
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
if src == dist.get_rank(group):
return obj
else:
return None
if group is None:
group = DEFAULT_TORCH_GROUP
cur_rank = dist.get_rank(group)
@ -289,50 +281,23 @@ def all_gather_object(object_list, obj, group=None):
"""
Copied from pytorch so that it stays version-compatible with older pytorch releases.
Gathers picklable objects from the whole group into a list. Similar to
:func:`all_gather`, but Python objects can be passed in. Note that the object
must be picklable in order to be gathered.
Args:
object_list (list[Any]): Output list. It should be correctly sized as the
size of the group for this collective and will contain the output.
object (Any): Picklable Python object to be broadcast from current process.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.
Returns:
None. If the calling rank is part of this group, the output of the
collective will be populated into the input ``object_list``. If the
calling rank is not part of the group, the passed in ``object_list`` will
be unmodified.
.. note:: Note that this API differs slightly from the :func:`all_gather`
collective since it does not provide an ``async_op`` handle and thus
will be a blocking call.
.. note:: For NCCL-based process groups, internal tensor representations
of objects must be moved to the GPU device before communication takes
place. In this case, the device used is given by
``torch.cuda.current_device()`` and it is the user's responsibility to
ensure that this is set so that each rank has an individual GPU, via
``torch.cuda.set_device()``.
.. warning::
:func:`all_gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.
Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]
:param object_list:
:param obj:
:param group:
:return:
"""
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]
if dist.distributed_c10d._rank_not_in_group(group):
return
if _TORCH_GREATER_EQUAL_1_8:
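A hedged usage sketch for the helpers above, meant to run under e.g. torchrun --nproc_per_node=2; the initialization details are assumptions:

import torch
import torch.distributed as dist

# assumes dist.init_process_group(...) has already run (e.g. via torchrun)
device = torch.device('cuda', dist.get_rank())
obj = {'rank': dist.get_rank(), 'score': 0.5 * dist.get_rank()}
gathered = fastnlp_torch_all_gather(obj)                            # [obj_rank0, obj_rank1] on every rank
synced = fastnlp_torch_broadcast_object(obj, src=0, device=device)  # every rank receives rank 0's obj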

View File

@ -35,7 +35,6 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
"'jittor', 'paddle', 'fleet'].")
def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None):
"""
Starts a new process with command via subprocess.Popen.
@ -60,30 +59,3 @@ def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:
err_f = open(output_from_new_proc + f'/{rank}_err.log', 'w')
proc = subprocess.Popen(command, env=env_copy, stdout=std_f, stderr=err_f)
return proc
def load_model(filepath: Union[str, Path], backend: str = "torch", **kwargs):
r"""
The counterpart of `save_model`, used to help users load a model previously saved through `save_model`.
:param filepath: location of the file to load
:param backend: which backend to use to load filepath; currently supports ["torch", "paddle", "jittor"]
"""
if filepath is None:
raise ValueError("Parameter `path` can not be None.")
assert backend is not None, "Parameter `backend` can not be None."
if backend == "torch":
import torch
_res = torch.load(filepath)
return _res
elif backend == "jittor":
raise NotImplementedError
elif backend == "paddle":
raise NotImplementedError
else:
raise ValueError("Parameter `backend` could only be one of these values: ['torch', 'jittor', 'paddle']")

View File

@ -7,7 +7,6 @@ __all__ = [
'TorchBackend',
'SpanFPreRecMetric',
'ClassifyFPreRecMetric',
'func_post_proc'
]
from .metric import Metric
@ -15,4 +14,3 @@ from .accuracy import Accuracy
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend
from .span_f1_pre_rec_metric import SpanFPreRecMetric
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
from .utils import func_post_proc

View File

@ -13,15 +13,22 @@ from fastNLP.core.utils.utils import seq_len_to_mask
class Accuracy(Metric):
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None):
"""
Metric that computes accuracy.
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
:param str backend: currently four backends are supported: ['auto', 'torch', 'paddle', 'jittor'], where auto means the actual backend is decided by the arguments passed to Metric.update()
when it is actually called. In most cases simply using 'auto' is fine.
:param bool aggregate_when_get_metric: whether, when computing the metric, the values of the same element on each process are automatically aggregated first.
When the backend does not support distributed execution this parameter is meaningless. If None, it is set automatically in the Evaluator according to whether the sampler is distributed.
"""
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend)
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend)
def get_metric(self) -> dict:
r"""
The get_metric function computes the final evaluation result from the statistics accumulated by the evaluate function.
:return dict evaluate_result: {"acc": float}
"""

View File

@ -3,35 +3,32 @@ __all__ = [
]
from typing import Union, List
from collections import defaultdict
from functools import partial
from collections import Counter
import warnings
from .metric import Metric
from .backend import Backend
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.utils.utils import seq_len_to_mask
def _compute_f_pre_rec(beta_square, tp, fn, fp):
r"""
:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
return f, pre, rec
from .utils import _compute_f_pre_rec
class ClassifyFPreRecMetric(Metric):
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0,
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None,
only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto',
aggregate_when_get_metric: bool = False) -> None:
aggregate_when_get_metric: bool = None) -> None:
"""
:param tag_vocab:
:param ignore_labels:
:param only_gross:
:param f_type:
:param beta:
:param str backend: currently four backends are supported: ['torch', 'paddle', 'jittor', 'auto'], where auto means the actual backend is decided by the arguments passed to Metric.update()
when it is actually called. In most cases simply using auto is fine.
:param bool aggregate_when_get_metric: whether, when computing the metric, the values of the same element on each process are automatically aggregated first.
When the backend does not support distributed execution this parameter is meaningless. If None, it is set automatically in the Evaluator according to whether the sampler is distributed.
"""
super(ClassifyFPreRecMetric, self).__init__(backend=backend,
aggregate_when_get_metric=aggregate_when_get_metric)
if f_type not in ('micro', 'macro'):
@ -47,32 +44,15 @@ class ClassifyFPreRecMetric(Metric):
self.tag_vocab = tag_vocab
self._tp = {}
self._fp = {}
self._fn = {}
if tag_vocab:
for word, _ in tag_vocab:
word = word.lower()
if word != 'o':
word = word[2:]
if word in self._true_positives:
continue
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
backend=backend)
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
backend=backend)
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
backend=backend)
elif num_class > 0:
for word in range(num_class):
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
backend=backend)
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
backend=backend)
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
backend=backend)
else:
raise ValueError()
self._tp = Counter()
self._fp = Counter()
self._fn = Counter()
def reset(self):
# since these are no longer Elements, they have to be cleared manually
self._tp.clear()
self._fp.clear()
self._fn.clear()
def get_metric(self) -> dict:
r"""
@ -81,10 +61,22 @@ class ClassifyFPreRecMetric(Metric):
:return dict evaluate_result: {"acc": float}
"""
evaluate_result = {}
# collect the results from each card via all_gather_object and sum them.
if self.aggregate_when_get_metric:
ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
tps, fps, fns = zip(*ls)
_tp, _fp, _fn = Counter(), Counter(), Counter()
for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
for _c in cs:
c.update(_c)
else:
_tp, _fp, _fn = self._tp, self._fp, self._fn
if not self.only_gross or self.f_type == 'macro':
tags = set(self._fn.keys())
tags.update(set(self._fp.keys()))
tags.update(set(self._tp.keys()))
tags = set(_fn.keys())
tags.update(set(_fp.keys()))
tags.update(set(_tp.keys()))
f_sum = 0
pre_sum = 0
rec_sum = 0
@ -93,9 +85,9 @@ class ClassifyFPreRecMetric(Metric):
tag_name = self.tag_vocab.to_word(tag)
else:
tag_name = int(tag)
tp = self._tp[tag].get_scalar()
fn = self._fn[tag].get_scalar()
fp = self._fp[tag].get_scalar()
tp = _tp[tag]
fn = _fn[tag]
fp = _fp[tag]
if tp == fn == fp == 0:
continue
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
@ -116,10 +108,7 @@ class ClassifyFPreRecMetric(Metric):
evaluate_result['rec'] = rec_sum / len(tags)
if self.f_type == 'micro':
f, pre, rec = _compute_f_pre_rec(self.beta_square,
sum(val.get_scalar() for val in self._tp.values()),
sum(val.get_scalar() for val in self._fn.values()),
sum(val.get_scalar() for val in self._fp.values()))
f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values()))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec
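The Counter merge in get_metric, demonstrated standalone with two made-up per-rank results:

from collections import Counter

# shape of what all_gather_object returns: one [tp, fp, fn] triple per rank
ls = [
    [Counter({'pos': 3, 'neg': 1}), Counter({'pos': 1}), Counter({'neg': 2})],
    [Counter({'pos': 2}), Counter({'neg': 1}), Counter({'pos': 1})],
]
tps, fps, fns = zip(*ls)
_tp, _fp, _fn = Counter(), Counter(), Counter()
for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
    for _c in cs:
        c.update(_c)
assert _tp == Counter({'pos': 5, 'neg': 1})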

View File

@ -35,6 +35,8 @@ class Element:
"""
self._check_value_initialized()
if self.aggregate_method is None: # if no aggregate method is set, skip aggregation.
return
try:
self._value = self.backend.aggregate(self._value, self.aggregate_method)
except AggregateMethodError as e:

View File

@ -14,13 +14,13 @@ from fastNLP.core.metrics.element import Element
class Metric:
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None):
"""
:param str backend: currently four backends are supported: ['torch', 'paddle', 'jittor', 'auto'], where auto means the actual backend is decided by the arguments passed to Metric.update()
when it is actually called. In most cases simply using auto is fine.
:param bool aggregate_when_get_metric: whether, when computing the metric, the values of the same element on each process are automatically aggregated first.
When the backend does not support distributed execution this parameter is meaningless.
When the backend does not support distributed execution this parameter is meaningless. If None, it is set automatically in the Evaluator according to whether the sampler is distributed.
"""
self.backend = AutoBackend(backend)
self._updated = False
@ -43,7 +43,7 @@ class Metric:
:param name: the name of this element; after registration it can be accessed in the Metric via self.{name}
:param value: the initial value; it is also reset to this value whenever Metric.reset() is called
:param aggregate_method: how to aggregate the results across multiple cards; meaningless when running on a single card
:param aggregate_method: how to aggregate the results across multiple cards; meaningless when running on a single card. If set to None, this element is not aggregated.
:param backend: the backend to use. The Element's actual type is initialized according to the backend: e.g. if the backend is torch, the object is a
torch.Tensor; if the backend is paddle, it is a paddle.tensor; and if the backend is jittor, it is a jittor.Var.
In most cases the default auto is fine: fastNLP initializes it sensibly based on the arguments passed to the actual Metric.update() call, e.g. when the passed
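A hedged sketch of a custom metric built on register_element, following the parameter description above; the Element arithmetic (+= and get_scalar) mirrors how Accuracy uses its registered elements:

class SampleCount(Metric):
    def __init__(self, backend='auto', aggregate_when_get_metric=None):
        super().__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
        # registered elements are sum-aggregated across cards when aggregation is on
        self.register_element(name='total', value=0, aggregate_method='sum', backend=backend)

    def update(self, pred):
        self.total += len(pred)

    def get_metric(self) -> dict:
        return {'total': self.total.get_scalar()}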

View File

@ -4,12 +4,12 @@ __all__ = [
from typing import Union, List, Optional
import warnings
from collections import defaultdict
from functools import partial
from collections import Counter
from fastNLP.core.metrics.backend import Backend
from fastNLP.core.metrics.metric import Metric
from fastNLP.core.vocabulary import Vocabulary
from .utils import _compute_f_pre_rec
def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str):
@ -199,26 +199,11 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]
def _compute_f_pre_rec(beta_square, tp, fn, fp):
r"""
:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
return f, pre, rec
class SpanFPreRecMetric(Metric):
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None,
only_gross: bool = True, f_type='micro',
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None:
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None) -> None:
r"""
:param tag_vocab: the :class:`~fastNLP.Vocabulary` of the tags. Supported tags are "B" (with no label) and "B-xxx" (where xxx is some label, e.g. NN in POS tagging);
@ -234,7 +219,7 @@ class SpanFPreRecMetric(Metric):
:param str backend: currently four backends are supported: ['auto', 'torch', 'paddle', 'jittor'], where auto means the actual backend is decided by the arguments passed to Metric.update()
when it is actually called. In most cases simply using 'auto' is fine.
:param bool aggregate_when_get_metric: whether, when computing the metric, the values of the same element on each process are automatically aggregated first.
When the backend does not support distributed execution this parameter is meaningless.
When the backend does not support distributed execution this parameter is meaningless. If None, it is set automatically in the Evaluator according to whether the sampler is distributed.
"""
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
if f_type not in ('micro', 'macro'):
@ -266,32 +251,40 @@ class SpanFPreRecMetric(Metric):
self.only_gross = only_gross
self.tag_vocab = tag_vocab
self._true_positives = {}
self._false_positives = {}
self._false_negatives = {}
for word, _ in tag_vocab:
word = word.lower()
if word != 'o':
word = word[2:]
if word in self._true_positives:
continue
self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend)
self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend)
self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend)
self._tp = Counter()
self._fp = Counter()
self._fn = Counter()
def reset(self):
self._tp.clear()
self._fp.clear()
self._fn.clear()
def get_metric(self) -> dict:
evaluate_result = {}
# collect the results from each card via all_gather_object and sum them.
if self.aggregate_when_get_metric:
ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
tps, fps, fns = zip(*ls)
_tp, _fp, _fn = Counter(), Counter(), Counter()
for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
for _c in cs:
c.update(_c)
else:
_tp, _fp, _fn = self._tp, self._fp, self._fn
if not self.only_gross or self.f_type == 'macro':
tags = set(self._false_negatives.keys())
tags.update(self._false_positives.keys())
tags.update(self._true_positives.keys())
tags = set(_fn.keys())
tags.update(_fp.keys())
tags.update(_tp.keys())
f_sum = 0
pre_sum = 0
rec_sum = 0
for tag in tags:
tp = self._true_positives[tag].get_scalar()
fn = self._false_negatives[tag].get_scalar()
fp = self._false_positives[tag].get_scalar()
tp = _tp[tag]
fn = _fn[tag]
fp = _fp[tag]
if tp == fn == fp == 0:
continue
@ -313,17 +306,7 @@ class SpanFPreRecMetric(Metric):
evaluate_result['rec'] = rec_sum / len(tags)
if self.f_type == 'micro':
tp, fn, fp = [], [], []
for val in self._true_positives.values():
tp.append(val.get_scalar())
for val in self._false_negatives.values():
fn.append(val.get_scalar())
for val in self._false_positives.values():
fp.append(val.get_scalar())
f, pre, rec = _compute_f_pre_rec(self.beta_square,
sum(tp),
sum(fn),
sum(fp))
f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values()))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec
@ -372,9 +355,9 @@ class SpanFPreRecMetric(Metric):
for span in pred_spans:
if span in gold_spans:
self._true_positives[span[0]] += 1
self._tp[span[0]] += 1
gold_spans.remove(span)
else:
self._false_positives[span[0]] += 1
self._fp[span[0]] += 1
for span in gold_spans:
self._false_negatives[span[0]] += 1
self._fn[span[0]] += 1

View File

@ -1,5 +1,4 @@
__all__ = [
'func_post_proc'
]
from typing import Any
@ -59,34 +58,23 @@ def _is_paddle_metric(metric: Any) -> bool:
return False
def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric':
"""
Wraps fn around the {method_name} method of the metric object, so that the return value of metric.{method_name} is first processed by fn
before being returned. Note that the modification of the metric's {method_name} method is in-place.
:param metric: the metric object
:param fn: applied to the return value of the metric's accumulate method
:param method_name: the name of the method to wrap
:return: metric
"""
assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \
f"Parameter `metric` must have a {method_name} function."
assert callable(fn), "Parameter `fn` must be callable."
func = getattr(metric, method_name)
@wraps(func)
def wrap_method(*args, **kwargs):
res = func(*args, **kwargs)
return fn(res)
wrap_method.__wrapped_by_func_post_proc__ = True
setattr(metric, method_name, wrap_method)
return metric
class AggregateMethodError(BaseException):
def __init__(self, should_have_aggregate_method, only_warn=False):
super(AggregateMethodError, self).__init__(self)
self.should_have_aggregate_method = should_have_aggregate_method
self.only_warn = only_warn
def _compute_f_pre_rec(beta_square, tp, fn, fp):
r"""
:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
return f, pre, rec
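A numeric sanity check of the formula above with beta = 1 (so beta_square = 1); the 1e-13 terms only guard against division by zero:

# tp=6, fp=2, fn=4 -> pre = 6/8 = 0.75, rec = 6/10 = 0.6,
# f1 = 2 * 0.75 * 0.6 / (0.75 + 0.6) = 0.9 / 1.35 ~ 0.6667
f, pre, rec = _compute_f_pre_rec(1, tp=6, fn=4, fp=2)
assert abs(f - 2 / 3) < 1e-6 and abs(pre - 0.75) < 1e-6 and abs(rec - 0.6) < 1e-6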

View File

@ -222,14 +222,14 @@ class RandomSampler(ReproducibleSampler):
class SequentialSampler(RandomSampler):
def __init__(self, dataset, dist_mode:str='interval', **kwargs):
def __init__(self, dataset, **kwargs):
"""
Reads the dataset in order. With multiple cards, reads at intervals: e.g. with two cards, card 0 takes [0,2,4,...] and card 1 takes [1,3,5,...].
:param dataset: a data container that implements the __len__ method
:param kwargs:
"""
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
super().__init__(dataset=dataset, **kwargs)
def __iter__(self):
if self.during_iter: # if _during_iter is True, the previous iteration has not finished; the only option is to force re-initialization

View File

@ -6,8 +6,9 @@ __all__ = [
'is_cur_env_distributed',
'get_global_rank',
'rank_zero_call',
'all_rank_call',
'get_gpu_count'
'all_rank_call_context',
'get_gpu_count',
'fastnlp_no_sync_context'
]

View File

@ -7,10 +7,11 @@ __all__ = [
'is_cur_env_distributed',
'get_global_rank',
'rank_zero_call',
'all_rank_call'
'all_rank_call_context',
'fastnlp_no_sync_context'
]
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC
def is_cur_env_distributed() -> bool:
@ -41,24 +42,46 @@ def rank_zero_call(fn: Callable):
return a+b
rank_zero_call(add)(1, 2)
This function also sets FASTNLP_NO_SYNC to 2; in that environment, none of fastNLP's built-in barrier and gather/broadcast operations have any
effect.
:param fn: the callable to wrap
:return:
"""
@wraps(fn)
def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
return fn(*args, **kwargs)
with fastnlp_no_sync_context(level=2):
return fn(*args, **kwargs)
return None
return wrapped_fn
@contextmanager
def all_rank_call():
def fastnlp_no_sync_context(level=2):
"""
Makes fastNLP's barrier and gather/broadcast operations behave as if the multi-card program had only one card. If the level is 1, fastNLP's barrier is disabled;
if it is 2, both barrier and gather/broadcast are disabled.
:param int level: one of [0, 1, 2]
:return:
"""
old_level = os.environ.get(FASTNLP_NO_SYNC, None)
os.environ[FASTNLP_NO_SYNC] = f'{level}'
yield
if old_level is None:
os.environ.pop(FASTNLP_NO_SYNC)
else:
os.environ[FASTNLP_NO_SYNC] = old_level
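A short usage sketch of the context manager above; driver stands for any fastNLP driver instance (illustrative):

with fastnlp_no_sync_context(level=1):
    driver.barrier()                          # skipped: barrier is disabled at level >= 1

with fastnlp_no_sync_context(level=2):
    res = driver.all_gather(obj, group=None)  # short-circuits to [obj] without communicating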
@contextmanager
def all_rank_call_context():
"""
In multi-card mode, FASTNLP_GLOBAL_RANK is temporarily set to "0" inside this context, which disables rank_zero_call so that every process runs the function.
# usage
with all_rank_run():
with all_rank_call_context():
do_something # all rank will do
:param fn:

View File

@ -48,6 +48,10 @@ FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH"
# the default size of deques initialized inside fastNLP
FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE'
# Used in fastNLP to disable 1. barrier and 2. gather/broadcast. Defaults to '0', meaning nothing is disabled; '1' means fastNLP's barrier is not executed;
# '2' means both barrier and gather/broadcast are disabled.
FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC'
# todo: document the variables below, which are used directly
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar"
FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar"

View File

@ -85,7 +85,7 @@ class TestFleetDriverFunction:
"""
Tests the get_model_no_sync_context function
"""
res = self.driver.get_no_sync_context()
res = self.driver.get_model_no_sync_context()
dist.barrier()
@magic_argv_env_context

View File

@ -67,7 +67,7 @@ class TestClassfiyFPreRecMetric:
target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3,
0, 3, 0, 0, 0, 1, 3, 1])
metric = ClassifyFPreRecMetric(f_type='macro', num_class=5)
metric = ClassifyFPreRecMetric(f_type='macro')
metric.update(pred, target)
result_dict = metric.get_metric()
f1_score = 0.1882051282051282
@ -78,7 +78,7 @@ class TestClassfiyFPreRecMetric:
for keys in ['f', 'pre', 'rec']:
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
metric = ClassifyFPreRecMetric(f_type='micro', num_class=5)
metric = ClassifyFPreRecMetric(f_type='micro')
metric.update(pred, target)
result_dict = metric.get_metric()
f1_score = 0.21875
@ -89,7 +89,7 @@ class TestClassfiyFPreRecMetric:
for keys in ['f', 'pre', 'rec']:
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5)
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro')
metric.update(pred, target)
result_dict = metric.get_metric()
ground_truth = {
@ -157,7 +157,6 @@ class TestClassfiyFPreRecMetric:
})
metric_kwargs = {
'f_type': f_type,
'num_class': 5,
'only_gross': False,
'aggregate_when_get_metric': True
}

View File

@ -102,7 +102,8 @@ class TestSpanFPreRecMetric:
# bio tag
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels))
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False,
aggregate_when_get_metric=True)
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
-0.3782, 0.8240],
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,

View File

@ -1,32 +0,0 @@
import unittest
from fastNLP.core.metrics.utils import func_post_proc
class Metric:
def accumulate(self, x, y):
return x, y
def compute(self, x, y):
return x, y
class TestMetricUtil(unittest.TestCase):
def test_func_post_proc(self):
metric = Metric()
metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='accumulate')
self.assertDictEqual({'x': 1, 'y': 2}, metric.accumulate(x=1, y=2))
func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='accumulate')
self.assertDictEqual({'1': 1, '2': 2}, metric.accumulate(x=1, y=2))
metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update')
self.assertDictEqual({'x': 1, 'y': 2}, metric.update(x=1, y=2))
func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='update')
self.assertDictEqual({'1': 1, '2': 2}, metric.update(x=1, y=2))
def test_check_accumulate_post_special_local_variable(self):
metric = Metric()
self.assertFalse(hasattr(metric, '__wrapped_by_fn__'))
metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update')
self.assertTrue(hasattr(metric, '__wrapped_by_fn__'))

View File

@ -1,6 +1,6 @@
import os
from fastNLP.envs.distributed import rank_zero_call, all_rank_call
from fastNLP.envs.distributed import rank_zero_call, all_rank_call_context
from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context
@ -70,7 +70,7 @@ class TestTorch:
re_run_current_cmd_for_torch(1, output_from_new_proc='all')
# torch.distributed.init_process_group(backend='nccl')
# torch.distributed.barrier()
with all_rank_call():
with all_rank_call_context():
with Capturing(no_del=True) as output:
write_something()
output = output[0]
@ -80,7 +80,7 @@ class TestTorch:
else:
assert '11111' in output
with all_rank_call():
with all_rank_call_context():
with Capturing(no_del=True) as output:
rank_zero_call(write_other_thing)()