Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-01 03:37:55 +08:00)

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

Commit a25a73394b
@@ -138,10 +138,6 @@ class CheckpointCallback(HasMonitorCallback):
                f'exception_{exception.__class__.__name__}'
            self.save(trainer=trainer, folder_name=folder_name)

-    def on_sanity_check_end(self, trainer, sanity_check_res):
-        # Mainly to check whether the monitor exists.
-        self.get_monitor_value(results=sanity_check_res)
-
    def on_save_checkpoint(self, trainer) -> Dict:
        """
        Save timestamp_path so that training can later resume and keep saving into this folder.
@@ -49,7 +49,8 @@ class HasMonitorCallback(Callback):
            self.monitor = monitor
        else:
            self.monitor = str(monitor) if monitor is not None else None
-        self.larger_better = bool(larger_better)
+        if self.monitor is not None:
+            self.larger_better = bool(larger_better)
        if larger_better:
            self.monitor_value = float('-inf')
        else:

@@ -71,6 +72,12 @@ class HasMonitorCallback(Callback):
            raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
                               f"You can set it in the initialization or through Trainer.")

+    def on_sanity_check_end(self, trainer, sanity_check_res):
+        # Mainly to check whether the monitor exists.
+        if self.monitor is not None:
+            self.get_monitor_value(results=sanity_check_res)
+
    def get_monitor_value(self, results:Dict)->Union[float, None]:
        """
        Get the value of the monitor. If the monitor key is not found directly, fuzzy matching is attempted, and the matched key is stored on the self._real_monitor attribute.
@@ -10,7 +10,7 @@ import shutil

from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH
from fastNLP.core.log import logger
-from fastNLP.envs import all_rank_call
+from fastNLP.envs import all_rank_call_context


class LoadBestModelCallback(HasMonitorCallback):

@@ -76,9 +76,6 @@ class LoadBestModelCallback(HasMonitorCallback):

        super().on_after_trainer_initialized(trainer, driver)

-    def on_sanity_check_end(self, trainer, sanity_check_res):
-        self.get_monitor_value(sanity_check_res)
-
    def on_validate_end(self, trainer, results):
        if self.is_better_results(results, keep_if_better=True):
            if self.real_save_folder:

@@ -86,7 +83,7 @@ class LoadBestModelCallback(HasMonitorCallback):
                                   model_save_fn=self.model_save_fn)
            else:
                self.buffer.seek(0)
-                with all_rank_call():
+                with all_rank_call_context():
                    trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict)

    def on_train_end(self, trainer):
@@ -11,14 +11,15 @@ class LRSchedCallback(Callback):
        Calls the scheduler's step function at the appropriate time, according to the step_on parameter.

        :param scheduler: an object that implements a step() function
-        :param step_on: one of ['batch', 'epoch'], indicating when to call the scheduler's step function
+        :param step_on: one of ['batch', 'epoch'], indicating when to call the scheduler's step function. For 'batch' it is called
+            before each parameter update; for 'epoch' it is called after an epoch has finished.
        """
        assert hasattr(scheduler, 'step') and callable(scheduler.step), "The scheduler object should have a " \
                                                                        "step function."
        self.scheduler = scheduler
        self.step_on = 0 if step_on == 'batch' else 1

-    def on_train_batch_end(self, trainer):
+    def on_before_optimizers_step(self, trainer, optimizers):
        if self.step_on == 0:
            self.scheduler.step()
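Note on the hook change above: with step_on='batch', the scheduler now steps in on_before_optimizers_step instead of at batch end. A minimal stand-alone sketch of the resulting call order, using a plain torch optimizer and StepLR; the model, data, and loop here are illustrative placeholders, not fastNLP internals:

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    for _ in range(3):
        loss = model(torch.randn(8, 4)).sum()
        loss.backward()
        scheduler.step()       # mirrors what the callback now does right before the update
        optimizer.step()
        optimizer.zero_grad()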
@@ -32,10 +32,6 @@ class ProgressCallback(HasMonitorCallback):
    def on_train_end(self, trainer):
        f_rich_progress.stop()

-    def on_sanity_check_end(self, trainer, sanity_check_res):
-        if len(sanity_check_res) and getattr(self, 'monitor', None) is not None:
-            self.get_monitor_value(sanity_check_res)
-

class RichCallback(ProgressCallback):
    def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
@@ -3,7 +3,6 @@ from functools import partial
from dataclasses import is_dataclass
-import sys


__all__ = [
    'Evaluator'
]

@@ -75,8 +74,8 @@ class Evaluator:
            When auto_tensor_conversion_for_metric is True, fastNLP automatically converts the paddle tensors in the output (other
            non-tensor arguments are left untouched) into pytorch tensors before feeding them to the metrics. The tensor type of the model output is determined by the driver;
            the input types supported by the metrics are determined by the metrics. For more complex conversions, use the input_mapping and output_mapping parameters.
-        use_dist_sampler: whether to evaluate in a distributed way. Only effective when the driver is a distributed type. If True, the dataloader
-            on each process automatically uses different data, and the union of the data across all processes is the whole dataset. Make sure the metrics in use support automatic distributed accumulation.
+        use_dist_sampler: whether to evaluate in a distributed way. Only effective when the driver is a distributed type. By default this is set
+            according to whether the driver supports distributed execution. If True, the dataloader on each process automatically uses different data, and the union of the data across all processes is the whole dataset.
        output_from_new_proc: should be a string indicating how the output streams of the other processes of a multi-process driver are handled; its value should be one of
            ["all", "ignore", "only_error"]; when the value is none of these, it should name a folder: the output streams of the other ranks are redirected to
            log files, and the log files are saved in the folder given by this value; defaults to "only_error";

@@ -86,7 +85,8 @@ class Evaluator:

        self.model = model
        self.metrics = metrics
-        self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs)
+        self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call,
+                                    **kwargs)

        if dataloaders is None:
            raise ValueError("Parameter `dataloaders` can not be None.")

@@ -105,9 +105,13 @@ class Evaluator:
            dataloaders = {None: dataloaders}

        self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn)

+        self.driver.setup()
+        self.driver.barrier()
+
        self.separator = kwargs.get('separator', '#')
        self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True)
-        use_dist_sampler = kwargs.get("use_dist_sampler", False)  # if this is the Evaluator's own default, it should be False;
+        use_dist_sampler = kwargs.get("use_dist_sampler", driver.is_distributed())
        if use_dist_sampler:
            self._dist_sampler = "unrepeatdist"
        else:

@@ -115,8 +119,9 @@ class Evaluator:
        self._metric_wrapper = None
        _ = self.metrics_wrapper  # trigger the check

-        self.driver.setup()
-        self.driver.barrier()
+        if self._dist_sampler is not None and not self.driver.is_distributed():
+            logger.warning_once("Running in a non-distributed driver, but with distributed sampler, it may cause "
+                                "different process evaluating on different data.")

        if evaluate_fn is not None and not isinstance(evaluate_fn, str):
            raise TypeError("Parameter `evaluate_fn` can only be `str` type when it is not None.")

@@ -183,7 +188,7 @@ class Evaluator:

        return metric_results

-    def start_progress_bar(self, total:int, dataloader_name):
+    def start_progress_bar(self, total: int, dataloader_name):
        if self.progress_bar == 'rich':
            if dataloader_name is None:
                desc = f'Eval. Batch:0'

@@ -208,7 +213,7 @@ class Evaluator:
                                          advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True),
                                          visible=kwargs.get('visible', True))
        elif self.progress_bar == 'raw':
-            if self.verbose>1:
+            if self.verbose > 1:
                logger.info(desc)

    def remove_progress_bar(self, dataloader_name):

@@ -256,7 +261,7 @@ class Evaluator:
        """
        self.metrics_wrapper.update(*args, **kwargs)

-    def get_dataloader_metric(self, dataloader_name:Optional[str]='') -> Dict:
+    def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict:
        """
        Get the metric results of the current dataloader

@@ -313,6 +318,7 @@ class _MetricsWrapper:
    By wrapping the update(), reset(), and get_metric() functions, it supports fastNLP metrics as well as torchmetrics and more.
+
    """

    def __init__(self, metrics, evaluator):
        self.evaluator = evaluator
        self._metrics = []

@@ -326,13 +332,14 @@ class _MetricsWrapper:
                # torchmetrics enables multi-card aggregation by default
                evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device)
            elif isinstance(metric, Metric):
-                if evaluator._dist_sampler is not None and evaluator.driver.is_distributed() \
-                        and metric.aggregate_when_get_metric is False:
-                    logger.warning("You have replace the sampler as distributed sampler when evaluation, but your "
-                                   f"metric:{metric_name}' `aggregate_when_get_metric` is False.")
-                if evaluator._dist_sampler is None and evaluator.driver.is_distributed() \
-                        and metric.aggregate_when_get_metric is True:
-                    pass  # this case does not matter, because
+                # if the data is distributed but not aggregated, there may be a problem
+                if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False:
+                    logger.warning_once(
+                        "You have replace the sampler as distributed sampler when evaluation, but your "
+                        f"metric {metric_name}:{metric.__class__.__name__}' `aggregate_when_get_metric` is False.")
+                if metric.aggregate_when_get_metric is None:
+                    metric.aggregate_when_get_metric = evaluator._dist_sampler is not None

                metric.to(evaluator.driver.data_device)
            self._metric_names.append(metric_name)
            self._metrics.append(metric)

@@ -343,8 +350,9 @@ class _MetricsWrapper:
        for metric in self._metrics:
            args = []
            if not isinstance(batch, dict):
-                logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
-                                    f"the output of model to update metric.")
+                logger.warning_once(
+                    f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
+                    f"the output of model to update metric.")
            else:
                args.append(batch)
            if not isinstance(outputs, dict):

@@ -368,7 +376,7 @@ class _MetricsWrapper:
        elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric):
            metric.reset()

-    def get_metric(self, dataloader_name:str, separator:str) -> Dict:
+    def get_metric(self, dataloader_name: str, separator: str) -> Dict:
        """
        Flatten all metric results into a single-level dict whose keys follow the naming rule
        indicator_name{separator}metric_name{separator}dataloader_name

@@ -419,4 +427,4 @@ def _get_metric_res_name(dataloader_name: Optional[str], metric_name: str, indic
        names.append(dataloader_name)
    if len(names) == 0:
        raise RuntimeError("You cannot use empty `dataloader_name`, `metric_name`, and `monitor` simultaneously.")
-    return separator.join(names)
\ No newline at end of file
+    return separator.join(names)
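The key behavioral change in this file is the default for use_dist_sampler. A distilled sketch of the new fallback logic with a stand-in driver object (not the real fastNLP Driver):

    class FakeDriver:
        def is_distributed(self) -> bool:
            return False  # pretend single-process driver

    driver = FakeDriver()
    kwargs = {}  # user did not pass use_dist_sampler

    # old default: False; new default: follow the driver
    use_dist_sampler = kwargs.get("use_dist_sampler", driver.is_distributed())
    _dist_sampler = "unrepeatdist" if use_dist_sampler else None
    assert _dist_sampler is None  # non-distributed driver -> no replacement sampler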
@@ -122,7 +122,8 @@ class Trainer(TrainerEventTrigger):
        Note that if model_device is None, data_device has no effect;
        torch_ddp_kwargs: parameters used when initializing pytorch's DistributedDataParallel;
        set_grad_to_none: whether to set grad to None after every optimizer update during training;
-        use_dist_sampler: whether to replace the dataloader's sampler with a distributed sampler when using TorchDDPDriver; defaults to True;
+        use_dist_sampler: whether to use a distributed sampler. With multiple cards, the distributed sampler automatically decides which samples each card
+            reads, so that within one epoch the samples across all cards add up to the whole dataset. By default this is set according to whether the driver is distributed.
        use_eval_dist_sampler: whether, in the Evaluator, to replace the dataloader's sampler with a distributed sampler when using TorchDDPDriver; defaults to True;
        output_from_new_proc: should be a string indicating how the output streams of the other processes of a multi-process driver are handled; its value should be one of
            ["all", "ignore", "only_error"]; when the value is none of these, it should name a folder: the output streams of the other ranks are redirected to

@@ -211,12 +212,6 @@ class Trainer(TrainerEventTrigger):
            total_batches=None
        )

-        use_dist_sampler = kwargs.get("use_dist_sampler", True)
-        if use_dist_sampler:
-            _dist_sampler = "dist"
-        else:
-            _dist_sampler = None
-
        """ Set up the internal Evaluator """
        if metrics is None and evaluate_dataloaders is not None:
            raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.")

@@ -224,6 +219,18 @@ class Trainer(TrainerEventTrigger):
        if metrics is not None and evaluate_dataloaders is None:
            raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.")

+        self.metrics = metrics
+        self.validate_every = evaluate_every
+
+        self.driver.setup()
+        self.driver.barrier()
+
+        use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed())
+        if use_dist_sampler:
+            _dist_sampler = "dist"
+        else:
+            _dist_sampler = None
+
        self.evaluator = None
        self.monitor = monitor
        self.larger_better = larger_better

@@ -241,16 +248,10 @@ class Trainer(TrainerEventTrigger):
                output_mapping=output_mapping,
                fp16=fp16,
                verbose=0,
-                use_dist_sampler=kwargs.get("use_eval_dist_sampler", use_dist_sampler),
+                use_dist_sampler=kwargs.get("use_eval_dist_sampler", None),
                progress_bar=kwargs.get('progress_bar', 'auto')
            )

-        self.metrics = metrics
-        self.validate_every = evaluate_every
-
-        self.driver.setup()
-        self.driver.barrier()
-
        if train_fn is not None and not isinstance(train_fn, str):
            raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
        self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)

@@ -753,7 +754,7 @@ class Trainer(TrainerEventTrigger):
        """

        if (self.global_forward_batches + 1) % self.accumulation_steps != 0:
-            _no_sync_context = self.driver.get_no_sync_context()
+            _no_sync_context = self.driver.get_model_no_sync_context()
        else:
            _no_sync_context = nullcontext
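For context on the get_model_no_sync_context rename above: under gradient accumulation, skipped steps run inside DDP's no_sync so gradients are all-reduced only on the updating step. A hypothetical shape of that loop; ddp_model and batches are assumed placeholders, not actual Trainer attributes:

    from contextlib import nullcontext

    accumulation_steps = 4
    for i, batch in enumerate(batches):          # assumed iterable of inputs
        if (i + 1) % accumulation_steps != 0:
            no_sync = ddp_model.no_sync          # what the driver's context wraps
        else:
            no_sync = nullcontext
        with no_sync():
            loss = ddp_model(batch).sum() / accumulation_steps
            loss.backward()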
@@ -199,9 +199,10 @@ class Driver(ABC):
        """
        raise NotImplementedError("Each specific driver should implemented its own `zero_grad` function.")

-    def get_no_sync_context(self):
+    def get_model_no_sync_context(self):
        r"""
-        Returns a context object that turns off the mutual synchronization between processes; only multi-card drivers need to implement this function themselves, single-card drivers do not;
+        Returns a context object that turns off the model's automatic inter-process synchronization; only multi-card drivers need to
+        implement this function themselves, single-card drivers do not;

        :return: a context object similar to DistributedDataParallel(model).no_sync;
        """

@@ -357,6 +358,8 @@ class Driver(ABC):
        r"""
        Used to synchronize the progress of all processes in multi-process work; faster processes wait here for slower ones, and all processes continue only after every process has reached this function;
        only used in distributed training scenarios.
+
+        Note that the behavior of this function is affected by FASTNLP_NO_SYNC. The barrier is actually executed only when FASTNLP_NO_SYNC is absent from os.environ or is less than 1.
        """

    def is_distributed(self) -> bool:
@@ -82,7 +82,7 @@ class JittorMPIDriver(JittorDriver):
    def is_global_zero(self):
        return self.global_rank == 0

-    def get_no_sync_context(self):
+    def get_model_no_sync_context(self):
        return self.model.no_sync

    def unwrap_model(self):
@@ -403,7 +403,7 @@ class PaddleFleetDriver(PaddleDriver):
    def is_global_zero(self):
        return self.global_rank == 0

-    def get_no_sync_context(self):
+    def get_model_no_sync_context(self):
        return self.model.no_sync

    def unwrap_model(self):
@@ -5,7 +5,6 @@ import socket
import numpy as np
-from time import sleep
from typing import List, Optional, Union, Dict, Tuple, Callable
from functools import partial

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:

@@ -29,7 +28,7 @@ from fastNLP.core.drivers.utils import distributed_open_proc
from fastNLP.core.utils import auto_param_call, check_user_specific_params
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \
    re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
-from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
+from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
from fastNLP.core.log import logger
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object

@@ -511,7 +510,7 @@ class TorchDDPDriver(TorchDriver):
    def is_global_zero(self):
        return self.global_rank == 0

-    def get_no_sync_context(self):
+    def get_model_no_sync_context(self):
        # note that the model here is a "DistributedDataParallel" object;
        return self.model.no_sync

@@ -526,7 +525,8 @@ class TorchDDPDriver(TorchDriver):
        return self.local_rank

    def barrier(self):
-        torch.distributed.barrier(async_op=True)
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1:  # actually executed only when FASTNLP_NO_SYNC is less than 1
+            torch.distributed.barrier(async_op=True)

    def is_distributed(self):
        return True

@@ -544,6 +544,8 @@ class TorchDDPDriver(TorchDriver):
        :return: if the current driver is not distributed, the input obj is returned directly. If the current rank is a receiver (its global rank is contained in dst), the
            received object is returned; if it is the source, the broadcast content is returned; if it is neither sender nor receiver, None is returned.
        """
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # if FASTNLP_NO_SYNC == 2, return directly.
+            return
        return fastnlp_torch_broadcast_object(obj, src, device=self.data_device, group=group)

    def all_gather(self, obj, group) -> List:

@@ -569,6 +571,8 @@ class TorchDDPDriver(TorchDriver):
        :param group:
        :return:
        """
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # FASTNLP_NO_SYNC says not to execute
+            return [obj]
        return fastnlp_torch_all_gather(obj, group=group)
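The pattern introduced across this driver is a single env-var gate in front of every collective. Reduced to its essentials as a sketch (the env name is written as a literal here for self-containment; the callables are assumed stand-ins for the real torch.distributed collectives):

    import os

    def maybe_barrier(dist_barrier):
        # level >= 1 skips barriers entirely
        if int(os.environ.get('FASTNLP_NO_SYNC', 0)) < 1:
            dist_barrier(async_op=True)

    def maybe_all_gather(obj, real_all_gather):
        # level == 2 also short-circuits gather/broadcast: each rank just
        # sees its own object wrapped in a list
        if int(os.environ.get('FASTNLP_NO_SYNC', 0)) == 2:
            return [obj]
        return real_all_gather(obj)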
@@ -1,5 +1,6 @@
import io
import pickle
+import os
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
from typing import Any, List

@@ -7,6 +8,7 @@ from typing import Any, List
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.env import FASTNLP_NO_SYNC
if _NEED_IMPORT_TORCH:
    import torch
    from torch import distributed as dist

@@ -34,47 +36,15 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
        )


-def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
+def fastnlp_torch_gather_object(obj, dst=0, group=DEFAULT_TORCH_GROUP):
    """
+    Gather things from the other ranks onto the dst rank.
+
-    Gathers picklable objects from the whole group in a single process.
-    Similar to :func:`gather`, but Python objects can be passed in. Note that the
-    object must be picklable in order to be gathered.
-
-    Args:
-        obj (Any): Input object. Must be picklable.
-        object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
-            should be correctly sized as the size of the group for this
-            collective and will contain the output. Must be ``None`` on non-dst
-            ranks. (default is ``None``)
-        dst (int, optional): Destination rank. (default is 0)
-        group: (ProcessGroup, optional): The process group to work on. If None,
-            the default process group will be used. Default is ``None``.
-
-    Returns:
-        None. On the ``dst`` rank, ``object_gather_list`` will contain the
-        output of the collective.
-
-    .. note:: Note that this API differs slightly from the gather collective
-        since it does not provide an async_op handle and thus will be a blocking
-        call.
-
-    .. note:: Note that this API is not supported when using the NCCL backend.
-
-    .. warning::
-        :func:`gather_object` uses ``pickle`` module implicitly, which is
-        known to be insecure. It is possible to construct malicious pickle data
-        which will execute arbitrary code during unpickling. Only call this
-        function with data you trust.
-
    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
        >>> output = [None for _ in gather_objects]
-        >>> dist.gather_object(
+        >>> fastnlp_torch_gather_object(
                gather_objects[dist.get_rank()],
                output if dist.get_rank() == 0 else None,
                dst=0

@@ -82,7 +52,20 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAU
        >>> # On rank 0
        >>> output
        ['foo', 12, {1: 2}]

+    :param obj: the obj to send; it must be picklable
+    :param dst: the destination rank.
+    :param group: the group in which to run this function.
+    :return: on dst, a list of length world_size is returned, holding the obj from rank 0, rank 1, ... in order
    """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
+    if dist.get_rank() == dst:
+        object_gather_list = [None for _ in range(dist.get_world_size(group))]
+    else:
+        object_gather_list = None
+
    if group is None:
        group = DEFAULT_TORCH_GROUP

@@ -212,6 +195,9 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) -
    :param group:
    :return: the result is [obj0, obj1, ...], where obj_i is the obj on the i-th rank.
    """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
    if group is None:
        group = DEFAULT_TORCH_GROUP
    if isinstance(obj, torch.Tensor):

@@ -232,12 +218,18 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR
    """
    Broadcast the obj on src to the other ranks.

-    :param obj:
-    :param src:
+    :param obj: the object to send
+    :param src: where it is sent from.
    :param device:
-    :param group:
+    :param group: which communication group this belongs to
    :return:
    """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        if src == dist.get_rank(group):
+            return obj
+        else:
+            return None
+
    if group is None:
        group = DEFAULT_TORCH_GROUP
    cur_rank = dist.get_rank(group)

@@ -289,50 +281,23 @@ def all_gather_object(object_list, obj, group=None):
    """
    Copied from pytorch's code so as to remain compatible with lower pytorch versions.

-    Gathers picklable objects from the whole group into a list. Similar to
-    :func:`all_gather`, but Python objects can be passed in. Note that the object
-    must be picklable in order to be gathered.
-
-    Args:
-        object_list (list[Any]): Output list. It should be correctly sized as the
-            size of the group for this collective and will contain the output.
-        object (Any): Pickable Python object to be broadcast from current process.
-        group (ProcessGroup, optional): The process group to work on. If None,
-            the default process group will be used. Default is ``None``.
-
-    Returns:
-        None. If the calling rank is part of this group, the output of the
-        collective will be populated into the input ``object_list``. If the
-        calling rank is not part of the group, the passed in ``object_list`` will
-        be unmodified.
-
-    .. note:: Note that this API differs slightly from the :func:`all_gather`
-        collective since it does not provide an ``async_op`` handle and thus
-        will be a blocking call.
-
-    .. note:: For NCCL-based processed groups, internal tensor representations
-        of objects must be moved to the GPU device before communication takes
-        place. In this case, the device used is given by
-        ``torch.cuda.current_device()`` and it is the user's responsiblity to
-        ensure that this is set so that each rank has an individual GPU, via
-        ``torch.cuda.set_device()``.
-
-    .. warning::
-        :func:`all_gather_object` uses ``pickle`` module implicitly, which is
-        known to be insecure. It is possible to construct malicious pickle data
-        which will execute arbitrary code during unpickling. Only call this
-        function with data you trust.
-
    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
        >>> output = [None for _ in gather_objects]
-        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
+        >>> all_gather_object(output, gather_objects[dist.get_rank()])
        >>> output
        ['foo', 12, {1: 2}]

+    :param object_list:
+    :param obj:
+    :param group:
+    :return:
    """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
    if dist.distributed_c10d._rank_not_in_group(group):
        return
    if _TORCH_GREATER_EQUAL_1_8:
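A hedged usage sketch of the simplified fastnlp_torch_gather_object signature above, assuming an already-initialized two-process torch.distributed group; per the updated docstring, only dst receives the assembled list:

    import torch.distributed as dist

    # each rank contributes its own payload; rank 0 collects all of them
    payload = {'rank': dist.get_rank()}
    gathered = fastnlp_torch_gather_object(payload, dst=0)
    if dist.get_rank() == 0:
        print(gathered)  # e.g. [{'rank': 0}, {'rank': 1}]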
@@ -35,7 +35,6 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
                         "'jittor', 'paddle', 'fleet'].")


-
def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None):
    """
    Uses command to start a new process via subprocess.Popen.

@@ -60,30 +59,3 @@ def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:
        err_f = open(output_from_new_proc + f'/{rank}_err.log', 'w')
        proc = subprocess.Popen(command, env=env_copy, stdout=std_f, stderr=err_f)
    return proc
-
-
-def load_model(filepath: Union[str, Path], backend: str = "torch", **kwargs):
-    r"""
-    Corresponds to `load_model`, to help users load a model previously saved through `load_model`;
-
-    :param filepath: location of the file to load;
-    :param backend: which backend to use to load the filepath; currently supports ["torch", "paddle", "jittor"].
-    """
-
-    if filepath is None:
-        raise ValueError("Parameter `path` can not be None.")
-
-    assert backend is not None, "Parameter `backend` can not be None."
-
-    if backend == "torch":
-        import torch
-        _res = torch.load(filepath)
-        return _res
-    elif backend == "jittor":
-        raise NotImplementedError
-    elif backend == "paddle":
-        raise NotImplementedError
-    else:
-        raise ValueError("Parameter `backend` could only be one of these values: ['torch', 'jittor', 'paddle']")
@@ -7,7 +7,6 @@ __all__ = [
    'TorchBackend',
    'SpanFPreRecMetric',
    'ClassifyFPreRecMetric',
-    'func_post_proc'
]

from .metric import Metric

@@ -15,4 +14,3 @@ from .accuracy import Accuracy
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend
from .span_f1_pre_rec_metric import SpanFPreRecMetric
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
-from .utils import func_post_proc
@@ -13,15 +13,22 @@ from fastNLP.core.utils.utils import seq_len_to_mask


class Accuracy(Metric):
-    def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
+    def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None):
        """
+        Metric that computes accuracy.
+
        :param str backend: four kinds of backend are currently supported: ['auto', 'torch', 'paddle', 'jittor']. 'auto' means the concrete backend is decided by the
            arguments actually passed to Metric.update(); in general, simply using 'auto' is fine.
        :param bool aggregate_when_get_metric: whether, when computing the metric, the numbers of the same element on each process are automatically aggregated before
            the metric is obtained. When the backend does not support distributed execution, this parameter is meaningless. If None, it is set automatically in the Evaluator according to whether the sampler is distributed.
        """
        super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
        self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend)
        self.register_element(name='total', value=0, aggregate_method="sum", backend=backend)

    def get_metric(self) -> dict:
        r"""
        The get_metric function computes the final evaluation result from the metric statistics accumulated by the evaluate function.

        :return dict evaluate_result: {"acc": float}
        """
@@ -3,35 +3,32 @@ __all__ = [
]

from typing import Union, List
-from collections import defaultdict
-from functools import partial
+from collections import Counter
import warnings

from .metric import Metric
from .backend import Backend
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.utils.utils import seq_len_to_mask


-def _compute_f_pre_rec(beta_square, tp, fn, fp):
-    r"""
-
-    :param tp: int, true positive
-    :param fn: int, false negative
-    :param fp: int, false positive
-    :return: (f, pre, rec)
-    """
-    pre = tp / (fp + tp + 1e-13)
-    rec = tp / (fn + tp + 1e-13)
-    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
-
-    return f, pre, rec
+from .utils import _compute_f_pre_rec


class ClassifyFPreRecMetric(Metric):
-    def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0,
+    def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None,
                 only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto',
-                 aggregate_when_get_metric: bool = False) -> None:
+                 aggregate_when_get_metric: bool = None) -> None:
        """

        :param tag_vocab:
        :param ignore_labels:
        :param only_gross:
        :param f_type:
        :param beta:
        :param str backend: four kinds of backend are currently supported: [torch, paddle, jittor, auto]. 'auto' means the concrete backend is decided by the
            arguments actually passed to Metric.update(); in most cases simply using auto is fine.
        :param bool aggregate_when_get_metric: whether, when computing the metric, the numbers of the same element on each process are automatically aggregated before
            the metric is obtained. When the backend does not support distributed execution, this parameter is meaningless. If None, it is set automatically in the Evaluator according to whether the sampler is distributed.
        """
        super(ClassifyFPreRecMetric, self).__init__(backend=backend,
                                                    aggregate_when_get_metric=aggregate_when_get_metric)
        if f_type not in ('micro', 'macro'):

@@ -47,32 +44,15 @@ class ClassifyFPreRecMetric(Metric):

        self.tag_vocab = tag_vocab

-        self._tp = {}
-        self._fp = {}
-        self._fn = {}
-        if tag_vocab:
-            for word, _ in tag_vocab:
-                word = word.lower()
-                if word != 'o':
-                    word = word[2:]
-                if word in self._true_positives:
-                    continue
-                self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
-                                                       backend=backend)
-                self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
-                                                       backend=backend)
-                self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
-                                                       backend=backend)
-        elif num_class > 0:
-            for word in range(num_class):
-                self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
-                                                       backend=backend)
-                self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
-                                                       backend=backend)
-                self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
-                                                       backend=backend)
-        else:
-            raise ValueError()
+        self._tp = Counter()
+        self._fp = Counter()
+        self._fn = Counter()
+
+    def reset(self):
+        # since these are no longer Elements, they must be cleared manually
+        self._tp.clear()
+        self._fp.clear()
+        self._fn.clear()

    def get_metric(self) -> dict:
        r"""

@@ -81,10 +61,22 @@ class ClassifyFPreRecMetric(Metric):
        :return dict evaluate_result: {"acc": float}
        """
        evaluate_result = {}

+        # collect the results from every card via all_gather_object, then sum them.
+        if self.aggregate_when_get_metric:
+            ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
+            tps, fps, fns = zip(*ls)
+            _tp, _fp, _fn = Counter(), Counter(), Counter()
+            for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
+                for _c in cs:
+                    c.update(_c)
+        else:
+            _tp, _fp, _fn = self._tp, self._fp, self._fn
+
        if not self.only_gross or self.f_type == 'macro':
-            tags = set(self._fn.keys())
-            tags.update(set(self._fp.keys()))
-            tags.update(set(self._tp.keys()))
+            tags = set(_fn.keys())
+            tags.update(set(_fp.keys()))
+            tags.update(set(_tp.keys()))
            f_sum = 0
            pre_sum = 0
            rec_sum = 0

@@ -93,9 +85,9 @@ class ClassifyFPreRecMetric(Metric):
                    tag_name = self.tag_vocab.to_word(tag)
                else:
                    tag_name = int(tag)
-                tp = self._tp[tag].get_scalar()
-                fn = self._fn[tag].get_scalar()
-                fp = self._fp[tag].get_scalar()
+                tp = _tp[tag]
+                fn = _fn[tag]
+                fp = _fp[tag]
                if tp == fn == fp == 0:
                    continue
                f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)

@@ -116,10 +108,7 @@ class ClassifyFPreRecMetric(Metric):
            evaluate_result['rec'] = rec_sum / len(tags)

        if self.f_type == 'micro':
-            f, pre, rec = _compute_f_pre_rec(self.beta_square,
-                                             sum(val.get_scalar() for val in self._tp.values()),
-                                             sum(val.get_scalar() for val in self._fn.values()),
-                                             sum(val.get_scalar() for val in self._fp.values()))
+            f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec
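The new aggregation path in get_metric is easy to verify in isolation: all_gather_object conceptually yields one [tp, fp, fn] triple per rank, and the triples are summed key-wise with Counter.update. A self-contained re-enactment with two fake ranks standing in for the gathered results:

    from collections import Counter

    ls = [
        [Counter(a=2), Counter(a=1), Counter(b=1)],  # stand-in for rank 0
        [Counter(a=3), Counter(),    Counter(b=2)],  # stand-in for rank 1
    ]
    tps, fps, fns = zip(*ls)
    _tp, _fp, _fn = Counter(), Counter(), Counter()
    for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
        for _c in cs:
            c.update(_c)
    assert _tp['a'] == 5 and _fp['a'] == 1 and _fn['b'] == 3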
@@ -35,6 +35,8 @@ class Element:

        """
        self._check_value_initialized()
+        if self.aggregate_method is None:  # if there is no aggregate method, do not aggregate.
+            return
        try:
            self._value = self.backend.aggregate(self._value, self.aggregate_method)
        except AggregateMethodError as e:
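The two added lines give aggregate_method=None a defined meaning: such an element is simply skipped during cross-card aggregation. Distilled into a stand-alone toy class (not the real Element API):

    class ToyElement:
        def __init__(self, value, aggregate_method=None):
            self._value = value
            self.aggregate_method = aggregate_method

        def aggregate(self, reduce_fn):
            if self.aggregate_method is None:  # no method -> leave the value as-is
                return
            self._value = reduce_fn(self._value, self.aggregate_method)

    e = ToyElement(3)                  # registered without an aggregate_method
    e.aggregate(lambda v, m: v * 2)
    assert e._value == 3               # untouched, exactly like the early return above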
@@ -14,13 +14,13 @@ from fastNLP.core.metrics.element import Element


class Metric:
-    def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
+    def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None):
        """

        :param str backend: four kinds of backend are currently supported: [torch, paddle, jittor, auto]. 'auto' means the concrete backend is decided by the
            arguments actually passed to Metric.update(); in most cases simply using auto is fine.
        :param bool aggregate_when_get_metric: whether, when computing the metric, the numbers of the same element on each process are automatically aggregated before
-            the metric is obtained. When the backend does not support distributed execution, this parameter is meaningless.
+            the metric is obtained. When the backend does not support distributed execution, this parameter is meaningless. If None, it is set automatically in the Evaluator according to whether the sampler is distributed.
        """
        self.backend = AutoBackend(backend)
        self._updated = False

@@ -43,7 +43,7 @@ class Metric:

        :param name: the name of this element; once registered, it can be accessed in the Metric via self.{name}.
        :param value: the initial value. It is also reset to this value whenever Metric.reset() is called
-        :param aggregate_method: how to aggregate the results across multiple cards; meaningless when running on a single card.
+        :param aggregate_method: how to aggregate the results across multiple cards; meaningless when running on a single card. If set to None, this element is not aggregated.
        :param backend: the backend in use. The Element's type is initialized according to the backend: for example, if the backend is torch, the object is a
            Torch.tensor; if the backend is paddle, the object is a paddle.tensor; if the backend is jittor, the object is a jittor.Var.
            In general the default auto is fine; fastNLP initializes it sensibly according to the arguments actually passed to Metric.update(), for example when passing
@@ -4,12 +4,12 @@ __all__ = [

from typing import Union, List, Optional
import warnings
-from collections import defaultdict
-from functools import partial
+from collections import Counter

from fastNLP.core.metrics.backend import Backend
from fastNLP.core.metrics.metric import Metric
from fastNLP.core.vocabulary import Vocabulary
+from .utils import _compute_f_pre_rec


def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str):

@@ -199,26 +199,11 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
    return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]


-def _compute_f_pre_rec(beta_square, tp, fn, fp):
-    r"""
-
-    :param tp: int, true positive
-    :param fn: int, false negative
-    :param fp: int, false positive
-    :return: (f, pre, rec)
-    """
-    pre = tp / (fp + tp + 1e-13)
-    rec = tp / (fn + tp + 1e-13)
-    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
-
-    return f, pre, rec
-
-
class SpanFPreRecMetric(Metric):

    def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None,
                 only_gross: bool = True, f_type='micro',
-                 beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None:
+                 beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None) -> None:
        r"""

        :param tag_vocab: the :class:`~fastNLP.Vocabulary` of the tags. Supported tags are "B" (without a label) or "B-xxx" (where xxx is some label, such as NN in POS),

@@ -234,7 +219,7 @@ class SpanFPreRecMetric(Metric):
        :param str backend: four kinds of backend are currently supported: ['auto', 'torch', 'paddle', 'jittor']. 'auto' means the concrete backend is decided by the
            arguments actually passed to Metric.update(); in general, simply using 'auto' is fine.
        :param bool aggregate_when_get_metric: whether, when computing the metric, the numbers of the same element on each process are automatically aggregated before
-            the metric is obtained. When the backend does not support distributed execution, this parameter is meaningless.
+            the metric is obtained. When the backend does not support distributed execution, this parameter is meaningless. If None, it is set automatically in the Evaluator according to whether the sampler is distributed.
        """
        super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
        if f_type not in ('micro', 'macro'):

@@ -266,32 +251,40 @@ class SpanFPreRecMetric(Metric):
        self.only_gross = only_gross
        self.tag_vocab = tag_vocab

-        self._true_positives = {}
-        self._false_positives = {}
-        self._false_negatives = {}
-        for word, _ in tag_vocab:
-            word = word.lower()
-            if word != 'o':
-                word = word[2:]
-            if word in self._true_positives:
-                continue
-            self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend)
-            self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend)
-            self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend)
+        self._tp = Counter()
+        self._fp = Counter()
+        self._fn = Counter()
+
+    def reset(self):
+        self._tp.clear()
+        self._fp.clear()
+        self._fn.clear()

    def get_metric(self) -> dict:
        evaluate_result = {}

+        # collect the results from every card via all_gather_object, then sum them.
+        if self.aggregate_when_get_metric:
+            ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
+            tps, fps, fns = zip(*ls)
+            _tp, _fp, _fn = Counter(), Counter(), Counter()
+            for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
+                for _c in cs:
+                    c.update(_c)
+        else:
+            _tp, _fp, _fn = self._tp, self._fp, self._fn
+
        if not self.only_gross or self.f_type == 'macro':
-            tags = set(self._false_negatives.keys())
-            tags.update(self._false_positives.keys())
-            tags.update(self._true_positives.keys())
+            tags = set(_fn.keys())
+            tags.update(_fp.keys())
+            tags.update(_tp.keys())
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
-                tp = self._true_positives[tag].get_scalar()
-                fn = self._false_negatives[tag].get_scalar()
-                fp = self._false_positives[tag].get_scalar()
+                tp = _tp[tag]
+                fn = _fn[tag]
+                fp = _fp[tag]
                if tp == fn == fp == 0:
                    continue

@@ -313,17 +306,7 @@ class SpanFPreRecMetric(Metric):
            evaluate_result['rec'] = rec_sum / len(tags)

        if self.f_type == 'micro':
-            tp, fn, fp = [], [], []
-            for val in self._true_positives.values():
-                tp.append(val.get_scalar())
-            for val in self._false_negatives.values():
-                fn.append(val.get_scalar())
-            for val in self._false_positives.values():
-                fp.append(val.get_scalar())
-            f, pre, rec = _compute_f_pre_rec(self.beta_square,
-                                             sum(tp),
-                                             sum(fn),
-                                             sum(fp))
+            f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec

@@ -372,9 +355,9 @@ class SpanFPreRecMetric(Metric):

        for span in pred_spans:
            if span in gold_spans:
-                self._true_positives[span[0]] += 1
+                self._tp[span[0]] += 1
                gold_spans.remove(span)
            else:
-                self._false_positives[span[0]] += 1
+                self._fp[span[0]] += 1
        for span in gold_spans:
-            self._false_negatives[span[0]] += 1
+            self._fn[span[0]] += 1
@@ -1,5 +1,4 @@
__all__ = [
-    'func_post_proc'
]

from typing import Any

@@ -59,34 +58,23 @@ def _is_paddle_metric(metric: Any) -> bool:
    return False


-def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric':
-    """
-    Wraps the fn function around the {method_name} method of the metric object, so that the result of metric.{method_name} is first processed
-    by fn before it is returned. Note that the modification to the metric's {method_name} function is in place.
-
-    :param metric: the metric object
-    :param fn: applied to the return value of the metric's accumulate method
-    :param method_name: generally speaking, for
-    :return: metric
-    """
-    assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \
-        f"Parameter `metric` must have a {method_name} function."
-    assert callable(fn), "Parameter `fn` must be callable."
-
-    func = getattr(metric, method_name)
-
-    @wraps(func)
-    def wrap_method(*args, **kwargs):
-        res = func(*args, **kwargs)
-        return fn(res)
-
-    wrap_method.__wrapped_by_func_post_proc__ = True
-    setattr(metric, method_name, wrap_method)
-    return metric
-
-
class AggregateMethodError(BaseException):
    def __init__(self, should_have_aggregate_method, only_warn=False):
        super(AggregateMethodError, self).__init__(self)
        self.should_have_aggregate_method = should_have_aggregate_method
        self.only_warn = only_warn


+def _compute_f_pre_rec(beta_square, tp, fn, fp):
+    r"""
+
+    :param tp: int, true positive
+    :param fn: int, false negative
+    :param fp: int, false positive
+    :return: (f, pre, rec)
+    """
+    pre = tp / (fp + tp + 1e-13)
+    rec = tp / (fn + tp + 1e-13)
+    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
+
+    return f, pre, rec
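Since _compute_f_pre_rec is now the single shared helper, a quick numeric check of what it returns, copying the body above verbatim: with tp=6, fn=2, fp=3 and beta_square=1, precision is 6/9, recall is 6/8, and F1 is their harmonic mean:

    def _compute_f_pre_rec(beta_square, tp, fn, fp):
        pre = tp / (fp + tp + 1e-13)
        rec = tp / (fn + tp + 1e-13)
        f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
        return f, pre, rec

    f, pre, rec = _compute_f_pre_rec(1, tp=6, fn=2, fp=3)
    assert abs(pre - 6 / 9) < 1e-6
    assert abs(rec - 6 / 8) < 1e-6
    assert abs(f - 2 * pre * rec / (pre + rec)) < 1e-6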
@@ -222,14 +222,14 @@ class RandomSampler(ReproducibleSampler):


class SequentialSampler(RandomSampler):
-    def __init__(self, dataset, dist_mode:str='interval', **kwargs):
+    def __init__(self, dataset, **kwargs):
        """
        Reads the dataset in order. With multiple cards, reading is interleaved: for example, with two cards, card 0 takes [0,2,4,..] and card 1 takes [1,3,5...].

        :param dataset: a data container that implements the __len__ method.
        :param kwargs:
        """
-        super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
+        super().__init__(dataset=dataset, **kwargs)

    def __iter__(self):
        if self.during_iter:  # if _during_iter is True, the previous iteration has not finished; forcing re-initialization is the only option
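The docstring's interleaved reading is just strided slicing over ranks; a two-card illustration of the index split it describes:

    dataset = list(range(10))
    num_replicas = 2
    shards = [dataset[rank::num_replicas] for rank in range(num_replicas)]
    assert shards[0] == [0, 2, 4, 6, 8]  # card 0
    assert shards[1] == [1, 3, 5, 7, 9]  # card 1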
@@ -6,8 +6,9 @@ __all__ = [
    'is_cur_env_distributed',
    'get_global_rank',
    'rank_zero_call',
-    'all_rank_call',
-    'get_gpu_count'
+    'all_rank_call_context',
+    'get_gpu_count',
+    'fastnlp_no_sync_context'
]
@@ -7,10 +7,11 @@ __all__ = [
    'is_cur_env_distributed',
    'get_global_rank',
    'rank_zero_call',
-    'all_rank_call'
+    'all_rank_call_context',
+    'fastnlp_no_sync_context'
]

-from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
+from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC


def is_cur_env_distributed() -> bool:

@@ -41,24 +42,46 @@ def rank_zero_call(fn: Callable):
            return a+b
        rank_zero_call(add)(1, 2)

+    At the same time, this function also sets FASTNLP_NO_SYNC to 2; in that environment, none of fastNLP's built-in barrier interfaces or
+    gather/broadcast operations have any effect.
+
    :param fn: the callable to wrap.
    :return:
    """
    @wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
-            return fn(*args, **kwargs)
+            with fastnlp_no_sync_context(level=2):
+                return fn(*args, **kwargs)
        return None
    return wrapped_fn


@contextmanager
-def all_rank_call():
+def fastnlp_no_sync_context(level=2):
    """
+    Makes fastNLP's barrier and gather/broadcast operations behave as if the multi-card program had only one card. A level of 1 means fastNLP's
+    barrier operations are disabled; a level of 2 means both barrier and gather/broadcast are disabled.
+
+    :param int level: one of [0, 1, 2]
+    :return:
+    """
+    old_level = os.environ.get(FASTNLP_NO_SYNC, None)
+    os.environ[FASTNLP_NO_SYNC] = f'{level}'
+    yield
+    if old_level is None:
+        os.environ.pop(FASTNLP_NO_SYNC)
+    else:
+        os.environ[FASTNLP_NO_SYNC] = old_level
+
+
+@contextmanager
+def all_rank_call_context():
+    """
    In multi-card mode, FASTNLP_GLOBAL_RANK is temporarily set to "0" inside this context, which disables the rank_zero_call function so that every process runs the wrapped function.

    # usage
-    with all_rank_run():
+    with all_rank_call_context():
        do_something  # all rank will do

    :param fn:
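How the two new context managers compose, in a minimal sketch that relies only on the code above (the import path is assumed to match this commit): rank_zero_call wraps the rank-0 body in fastnlp_no_sync_context(level=2) so the single executing rank cannot block on a collective that the other ranks never reach.

    import os
    from fastNLP.envs.distributed import fastnlp_no_sync_context

    with fastnlp_no_sync_context(level=2):
        # inside: fastNLP's barrier and gather/broadcast are no-ops
        assert os.environ['FASTNLP_NO_SYNC'] == '2'
    # on exit the previous value (or absence) of FASTNLP_NO_SYNC is restored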
@@ -48,6 +48,10 @@ FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH"
# the default size of deques initialized in fastNLP
FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE'

+# used in fastNLP to disable 1. barrier and 2. gather/broadcast. The default '0' means nothing is disabled; '1' means fastNLP's barrier is not executed;
+# '2' means both barrier and gather/broadcast are disabled.
+FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC'
+
# todo: comment the variables that are used directly
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar"
FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar"
@@ -85,7 +85,7 @@ class TestFleetDriverFunction:
        """
        Tests the get_no_sync_context function
        """
-        res = self.driver.get_no_sync_context()
+        res = self.driver.get_model_no_sync_context()
        dist.barrier()

    @magic_argv_env_context
@@ -67,7 +67,7 @@ class TestClassfiyFPreRecMetric:
        target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3,
                               0, 3, 0, 0, 0, 1, 3, 1])

-        metric = ClassifyFPreRecMetric(f_type='macro', num_class=5)
+        metric = ClassifyFPreRecMetric(f_type='macro')
        metric.update(pred, target)
        result_dict = metric.get_metric()
        f1_score = 0.1882051282051282

@@ -78,7 +78,7 @@ class TestClassfiyFPreRecMetric:
        for keys in ['f', 'pre', 'rec']:
            np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)

-        metric = ClassifyFPreRecMetric(f_type='micro', num_class=5)
+        metric = ClassifyFPreRecMetric(f_type='micro')
        metric.update(pred, target)
        result_dict = metric.get_metric()
        f1_score = 0.21875

@@ -89,7 +89,7 @@ class TestClassfiyFPreRecMetric:
        for keys in ['f', 'pre', 'rec']:
            np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)

-        metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5)
+        metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro')
        metric.update(pred, target)
        result_dict = metric.get_metric()
        ground_truth = {

@@ -157,7 +157,6 @@ class TestClassfiyFPreRecMetric:
        })
        metric_kwargs = {
            'f_type': f_type,
-            'num_class': 5,
            'only_gross': False,
            'aggregate_when_get_metric': True
        }
@@ -102,7 +102,8 @@ class TestSpanFPreRecMetric:
        # bio tag
        fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
        fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels))
-        fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
+        fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False,
+                                               aggregate_when_get_metric=True)
        bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
                                            -0.3782, 0.8240],
                                           [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
@@ -1,32 +0,0 @@
-import unittest
-from fastNLP.core.metrics.utils import func_post_proc
-
-
-class Metric:
-    def accumulate(self, x, y):
-        return x, y
-
-    def compute(self, x, y):
-        return x, y
-
-
-class TestMetricUtil(unittest.TestCase):
-    def test_func_post_proc(self):
-        metric = Metric()
-        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='accumulate')
-        self.assertDictEqual({'x': 1, 'y': 2}, metric.accumulate(x=1, y=2))
-
-        func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='accumulate')
-        self.assertDictEqual({'1': 1, '2': 2}, metric.accumulate(x=1, y=2))
-
-        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update')
-        self.assertDictEqual({'x': 1, 'y': 2}, metric.update(x=1, y=2))
-
-        func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='update')
-        self.assertDictEqual({'1': 1, '2': 2}, metric.update(x=1, y=2))
-
-    def test_check_accumulate_post_special_local_variable(self):
-        metric = Metric()
-        self.assertFalse(hasattr(metric, '__wrapped_by_fn__'))
-        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update')
-        self.assertTrue(hasattr(metric, '__wrapped_by_fn__'))
@@ -1,6 +1,6 @@
import os

-from fastNLP.envs.distributed import rank_zero_call, all_rank_call
+from fastNLP.envs.distributed import rank_zero_call, all_rank_call_context
from tests.helpers.utils import re_run_current_cmd_for_torch, Capturing, magic_argv_env_context


@@ -70,7 +70,7 @@ class TestTorch:
        re_run_current_cmd_for_torch(1, output_from_new_proc='all')
        # torch.distributed.init_process_group(backend='nccl')
        # torch.distributed.barrier()
-        with all_rank_call():
+        with all_rank_call_context():
            with Capturing(no_del=True) as output:
                write_something()
            output = output[0]

@@ -80,7 +80,7 @@ class TestTorch:
        else:
            assert '11111' in output

-        with all_rank_call():
+        with all_rank_call_context():
            with Capturing(no_del=True) as output:
                rank_zero_call(write_other_thing)()