Added metrics

MorningForest 2022-04-08 21:32:26 +08:00
parent ce1b837d13
commit 4bb56616b9
16 changed files with 1661 additions and 0 deletions

View File

@@ -0,0 +1,18 @@
__all__ = [
"Metric",
"Accuracy",
'Backend',
'AutoBackend',
'PaddleBackend',
'TorchBackend',
'SpanFPreRecMetric',
'ClassifyFPreRecMetric',
'func_post_proc'
]
from .metric import Metric
from .accuracy import Accuracy
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend
from .span_f1_pre_rec_metric import SpanFPreRecMetric
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
from .utils import func_post_proc

View File

@@ -0,0 +1,75 @@
__all__ = [
'Accuracy'
]
from typing import Union
import warnings
import numpy as np
from fastNLP.core.metrics.metric import Metric
from fastNLP.core.metrics.backend import Backend
from fastNLP.core.utils.utils import seq_len_to_mask
class Accuracy(Metric):
def __init__(self, backend: Union[str, Backend, None] = 'auto',
aggregate_when_get_metric: bool = True):
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend)
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend)
def get_metric(self) -> dict:
r"""
Compute the final evaluation result from the statistics accumulated by update().
:return dict evaluate_result: {"acc": float}
"""
evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)}
return evaluate_result
def update(self, pred, target, seq_len=None):
r"""
update() accumulates the metric statistics over one batch of predictions.
:param torch.Tensor pred: prediction tensor of shape torch.Size([B,]), torch.Size([B, n_classes]),
torch.Size([B, max_len]) or torch.Size([B, max_len, n_classes])
:param torch.Tensor target: ground-truth tensor whose shape matches pred: torch.Size([B,]),
torch.Size([B,]), torch.Size([B, max_len]) or torch.Size([B, max_len]) respectively
:param torch.Tensor seq_len: sequence lengths; per pred shape above: None, None, torch.Size([B]) or torch.Size([B]).
seq_len is ignored if a mask is passed in as well.
"""
# For cross-framework compatibility, convert all inputs to numpy before computing.
pred = self.tensor2numpy(pred)
target = self.tensor2numpy(target)
if seq_len is not None:
seq_len = self.tensor2numpy(seq_len)
if seq_len is not None and target.ndim > 1:
max_len = target.shape[1]
masks = seq_len_to_mask(seq_len, max_len)
else:
masks = None
if pred.ndim == target.ndim:
if np.prod(pred.shape) != np.prod(target.shape):
raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers."
f" while target have shape:{target.shape}, "
f"pred have shape: {target.shape}")
elif pred.ndim == target.ndim + 1:
pred = pred.argmax(axis=-1)
if seq_len is None and target.ndim > 1:
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
else:
raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or "
f"{pred.shape[:-1]}, got {target.shape}.")
if masks is not None:
self.total += masks.sum().item()
self.correct += ((pred == target) * masks).sum().item()
else:
self.total += np.prod(list(pred.shape)).item()
self.correct += (target == pred).sum().item()
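A minimal usage sketch (not part of the commit, assuming torch is installed): update() accumulates statistics over batches and get_metric() reads out the result.

import torch
from fastNLP.core.metrics import Accuracy

acc = Accuracy(backend='torch')
pred = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # [B, n_classes]
target = torch.tensor([1, 1])                  # [B,]
acc.update(pred, target)                       # argmax over classes -> [1, 0]
print(acc.get_metric())                        # {'acc': 0.5}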

View File

@@ -0,0 +1,12 @@
__all__ = [
'Backend',
'AutoBackend',
'TorchBackend',
'PaddleBackend'
]
from .backend import Backend
from .auto_backend import AutoBackend
from .torch_backend.backend import TorchBackend
from .paddle_backend.backend import PaddleBackend

View File

@@ -0,0 +1,75 @@
from typing import Union
from .backend import Backend
from .torch_backend.backend import TorchBackend
from .paddle_backend.backend import PaddleBackend
from .jittor_backend.backend import JittorBackend
class AutoBackend(Backend):
"""
An AutoBackend needs no concrete backend at construction time; when left as 'auto', the concrete backend is chosen later from the types of the inputs the metric receives.
"""
def __init__(self, backend: Union[str, Backend, None]):
super(AutoBackend, self).__init__()
if backend != 'auto':
self._convert_backend(backend)
def _convert_backend(self, backend):
"""
Convert this AutoBackend into the appropriate concrete Backend.
"""
if isinstance(backend, Backend):
self.__class__ = backend.__class__
# if it is a str, just pick the matching backend
elif backend == 'torch':
self.__class__ = TorchBackend
elif backend == 'paddle':
self.__class__ = PaddleBackend
elif backend == 'jittor':
self.__class__ = JittorBackend
elif backend is None:
# nothing to do; stay on the default backend
pass
else:
raise RuntimeError(f"We did not support `{backend}` to be used as backend for now.")
self._specified = True
def choose_real_backend(self, args):
assert not self.is_specified(), "This method should not be called after backend has been specified. " \
"This must be a bug, please report."
types = []
for arg in args:
types.append(str(type(arg)))
torch_types = []
jittor_types = []
paddle_types = []
for type_name in types:
if 'torch' in type_name:
torch_types.append(type_name)
if 'paddle' in type_name:
paddle_types.append(type_name)
if 'jittor' in type_name:
jittor_types.append(type_name)
# following https://stackoverflow.com/a/3464154 , reassigning __class__ switches this object to the real backend
if len(torch_types) > 0 and len(jittor_types) == 0 and len(paddle_types) == 0:
backend = 'torch'
elif len(torch_types) == 0 and len(jittor_types) > 0 and len(paddle_types) == 0:
backend = 'jittor'
elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) > 0:
backend = 'paddle'
elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) == 0:
# no framework tensors detected; fall back to the default backend
backend = None
else:
types = list(set(torch_types + jittor_types + paddle_types))
raise RuntimeError(
f"Mixture of tensor type:{types} have been accept, please manually set backend instead of "
f"using backend=auto.")
self._convert_backend(backend)
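The `self.__class__` reassignment used in `_convert_backend` (per the linked Stack Overflow answer) swaps an instance's type in place. A standalone sketch of the idea, with hypothetical classes:

class Base:
    def who(self):
        return 'base'

class TorchLike(Base):
    def who(self):
        return 'torch'

obj = Base()
obj.__class__ = TorchLike  # the same object now dispatches through TorchLike
print(obj.who())           # torch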

View File

@@ -0,0 +1,75 @@
from ..utils import AggregateMethodError
class Backend:
"""
All methods of Backend and its subclasses must be stateless.
"""
def __init__(self):
self._specified = False
def aggregate(self, tensor, method: str):
"""
Aggregate the result across processes according to method and return it.
"""
if method is not None:
raise AggregateMethodError(should_have_aggregate_method=False, only_warn=True)
return tensor
def create_tensor(self, value: float):
"""
Create a tensor and fill it with value.
"""
return value
def fill_value(self, tensor, value: float):
"""
Set the tensor's value to value.
"""
return value
def get_scalar(self, tensor) -> float:
"""
Return the scalar value of tensor.
:param tensor:
:return:
"""
return tensor
def is_specified(self) -> bool:
"""
Whether this backend has been bound to a concrete framework.
:return:
"""
return self._specified
def tensor2numpy(self, tensor):
"""
Convert tensor to a numpy array.
:param tensor:
:return:
"""
return tensor
def move_tensor_to_device(self, tensor, device):
"""
"""
return tensor
def all_gather_object(self, obj, group=None):
"""
Gather obj from every rank; returns a list whose i-th entry is the obj from rank i.
:param obj:
:param group:
:return:
"""
raise NotImplementedError(f"all_gather_object() function is not implemented for {self.__class__.__name__}.")

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,72 @@
import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.metrics.backend import Backend
if _NEED_IMPORT_JITTOR:
import jittor
class JittorBackend(Backend):
def __init__(self):
super(JittorBackend, self).__init__()
self._specified = True
def aggregate(self, tensor, method: str):
"""
Aggregate the result according to method and return it; this backend currently returns the tensor unchanged.
"""
return tensor
def create_tensor(self, value: float):
"""
Create a tensor and fill it with value.
"""
value = jittor.Var(value)
return value
def fill_value(self, tensor, value: float):
"""
Set the tensor's value to value.
"""
value = jittor.full_like(tensor, value)
return value
def get_scalar(self, tensor) -> float:
"""
Return the scalar value of tensor.
:param tensor:
:return:
"""
return tensor.item()
def is_specified(self) -> bool:
"""
Whether this backend has been bound to a concrete framework.
:return:
"""
return self._specified
def tensor2numpy(self, tensor):
"""
Convert tensor to a numpy array.
:param tensor:
:return:
"""
if isinstance(tensor, jittor.Var):
return tensor.detach().numpy()
elif isinstance(tensor, np.ndarray):
return tensor
else:
raise ValueError(f"tensor: {tensor} can not convert to ndarray!")
def move_tensor_to_device(self, tensor, device):
"""
jittor has no device-transfer function, so this method is effectively a no-op.
"""
return tensor

View File

@@ -0,0 +1,5 @@
__all__ = [
'PaddleBackend'
]
from .backend import Backend as PaddleBackend

View File

@@ -0,0 +1,126 @@
from typing import List, Optional, Any
import numpy as np
from fastNLP.core.metrics.backend import Backend
from fastNLP.core.utils.paddle_utils import paddle_to
from fastNLP.core.metrics.utils import AggregateMethodError
from fastNLP.core.utils import is_in_paddle_dist
from fastNLP.core.drivers.paddle_driver.utils import get_device_from_visible
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE:
import paddle
from paddle.fluid.dygraph import parallel_helper
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
gathered_result = [paddle.zeros_like(result) for _ in range(world_size)]
paddle.distributed.all_gather(gathered_result, result, group)
return gathered_result
class PaddleBackend(Backend):
def __init__(self):
super().__init__()
self._specified = True
def aggregate(self, tensor, method: str):
"""
Aggregate the result across processes according to method and return it.
"""
if isinstance(tensor, paddle.Tensor):
if parallel_helper._is_parallel_ctx_initialized():
if method is None:
raise AggregateMethodError(should_have_aggregate_method=True)
tensor = self._gather_all(tensor)
if isinstance(tensor[0], paddle.Tensor):
tensor = paddle.stack(tensor)
# step 1: aggregate the gathered results (paddle reductions take `axis` and return a single tensor)
if method == 'sum':
tensor = paddle.sum(tensor, axis=0)
elif method == 'mean':
tensor = paddle.mean(tensor, axis=0)
elif method == 'max':
tensor = paddle.max(tensor, axis=0)
elif method == 'min':
tensor = paddle.min(tensor, axis=0)
else:
raise AggregateMethodError(should_have_aggregate_method=False)
return tensor
def create_tensor(self, value: float):
"""
Create a tensor and fill it with value.
"""
tensor = paddle.ones((1,)).fill_(value)
return tensor
def fill_value(self, tensor, value: float):
"""
Set the tensor's value to value.
"""
tensor.fill_(value)
return tensor
def get_scalar(self, tensor) -> float:
return tensor.item()
def tensor2numpy(self, tensor) -> np.array:
if isinstance(tensor, paddle.Tensor):
return tensor.cpu().detach().numpy()
elif isinstance(tensor, np.ndarray):
return tensor
else:
raise ValueError(f"tensor: {tensor} can not convert to ndarray!")
@staticmethod
def _gather_all(result, group: Optional[Any] = None) -> List:
"""
Gather result from every process in group; since result sizes may differ across processes, padding is applied where needed.
"""
# TODO: verify correctness
if group is None:
group = paddle.distributed.get_group(0)
world_size = group.nranks
paddle.distributed.barrier(group=group)
# scalar tensors can simply be gathered
if result.ndim == 0:
return _simple_gather_all_tensors(result, group, world_size)
# record the shape of result
local_size = paddle.to_tensor(result.shape)
# gather the sizes of result from every process in group
local_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)]
paddle.distributed.all_gather(local_sizes, local_size, group=group)
# stack the sizes and take the per-dimension maximum (paddle's max returns the tensor directly)
max_size = paddle.stack(local_sizes).max(axis=0)
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
# if all results share the same size, gather directly
if all_sizes_equal:
return _simple_gather_all_tensors(result, group, world_size)
# otherwise pad every tensor up to the largest size
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = paddle.nn.functional.pad(result, pad_dims)
# gather the padded tensors
gathered_result = [paddle.zeros_like(result_padded) for _ in range(world_size)]
paddle.distributed.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
return gathered_result
def move_tensor_to_device(self, tensor, device):
# TODO: could handling this here cause bugs elsewhere?
if is_in_paddle_dist():
device = get_device_from_visible(device)
return paddle_to(tensor, device)

View File

@@ -0,0 +1,6 @@
__all__ = [
'TorchBackend'
]
from .backend import Backend as TorchBackend

View File

@@ -0,0 +1,154 @@
from typing import Any, List, Optional
import numpy as np
from fastNLP.core.metrics.backend import Backend
from fastNLP.core.metrics.utils import AggregateMethodError
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed as dist
import torch.nn.functional as F
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
dist.all_gather(gathered_result, result, group)
return gathered_result
class TorchBackend(Backend):
def __init__(self):
super().__init__()
self._specified = True
def aggregate(self, tensor, method: str):
"""
Aggregate the result across processes according to method and return it.
"""
if isinstance(tensor, torch.Tensor):
if dist.is_initialized():
if method is None:
raise AggregateMethodError(should_have_aggregate_method=True)
tensor = self._gather_all(tensor)
if isinstance(tensor[0], torch.Tensor):
tensor = torch.stack(tensor)
# step 1: aggregate the gathered results
if method == 'sum':
tensor = torch.sum(tensor, dim=0)
elif method == 'mean':
tensor = torch.mean(tensor, dim=0)
elif method == 'max':
tensor, _ = torch.max(tensor, dim=0)
elif method == 'min':
tensor, _ = torch.min(tensor, dim=0)
else:
raise AggregateMethodError(should_have_aggregate_method=False)
return tensor
def create_tensor(self, value: float):
"""
Create a tensor and fill it with value.
"""
tensor = torch.ones(1).fill_(value)
return tensor
def fill_value(self, tensor, value: float):
"""
Set the tensor's value to value.
"""
tensor.fill_(value)
return tensor
def get_scalar(self, tensor) -> float:
return tensor.item()
@staticmethod
def _gather_all(result, group: Optional[Any] = None) -> List:
"""Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
tensors are padded, gathered and then trimmed to secure equal workload for all processes.
Args:
result: the value to sync
group: the process group to gather results from. Defaults to all processes (world)
Return:
gathered_result: list with size equal to the process group where
gathered_result[i] corresponds to result tensor from process i
"""
if group is None:
group = dist.group.WORLD
# convert tensors to contiguous format
result = result.contiguous()
world_size = dist.get_world_size(group)
dist.barrier(group=group)
# if the tensor is scalar, things are easy
if result.ndim == 0:
return _simple_gather_all_tensors(result, group, world_size)
# 1. Gather sizes of all tensors
local_size = torch.tensor(result.shape, device=result.device)
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
dist.all_gather(local_sizes, local_size, group=group)
max_size = torch.stack(local_sizes).max(dim=0).values
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
# 2. If shapes are all the same, then do a simple gather:
if all_sizes_equal:
return _simple_gather_all_tensors(result, group, world_size)
# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = torch.nn.functional.pad(result, pad_dims)
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
dist.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
return gathered_result
def tensor2numpy(self, tensor) -> np.array:
"""
Convert the given tensor to a numpy object.
"""
if isinstance(tensor, torch.Tensor):
return tensor.cpu().detach().numpy()
elif isinstance(tensor, np.ndarray):
return tensor
elif isinstance(tensor, (float, int)):
return tensor
else:
raise ValueError(f"tensor: {tensor} can not convert to ndarray!")
@staticmethod
def is_distributed() -> bool:
"""
:return: whether torch.distributed is available and initialized
"""
return dist.is_available() and dist.is_initialized()
def move_tensor_to_device(self, tensor, device):
return tensor.to(device)
def all_gather_object(self, obj, group=None) -> List:
if self.is_distributed():
obj_list = fastnlp_torch_all_gather(obj, group=group)
return obj_list
return [obj]
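The pad/gather/trim protocol in `_gather_all` can be sanity-checked locally without a process group. A sketch (assuming torch) of how one rank's tensor is padded up to the group-wide maximum shape and then trimmed back:

import torch
import torch.nn.functional as F

local = torch.arange(6).reshape(2, 3)  # this rank's result
max_size = torch.tensor([4, 3])        # per-dimension max shape across ranks
pad_by = (max_size - torch.tensor(local.shape)).tolist()
pad_dims = []
for val in reversed(pad_by):           # F.pad expects pads for the last dim first
    pad_dims.extend([0, val])
padded = F.pad(local, pad_dims)        # shape [4, 3], zero-padded
trimmed = padded[tuple(slice(d) for d in local.shape)]
assert torch.equal(trimmed, local)     # trimming recovers the original tensor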

View File

@@ -0,0 +1,142 @@
__all__ = [
'ClassifyFPreRecMetric'
]
from typing import Union, List
from collections import defaultdict
from functools import partial
import warnings
from .metric import Metric
from .backend import Backend
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.utils.utils import seq_len_to_mask
def _compute_f_pre_rec(beta_square, tp, fn, fp):
r"""
:param beta_square: the square of beta for the F-beta score
:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
return f, pre, rec
class ClassifyFPreRecMetric(Metric):
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False,
tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None,
only_gross: bool = True, f_type='micro', beta=1) -> None:
super(ClassifyFPreRecMetric, self).__init__(backend=backend,
aggregate_when_get_metric=aggregate_when_get_metric)
if f_type not in ('micro', 'macro'):
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
self.ignore_labels = ignore_labels
self.f_type = f_type
self.beta = beta
self.beta_square = self.beta ** 2
self.only_gross = only_gross
self.tag_vocab = tag_vocab
self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\
defaultdict(partial(self.register_element, aggregate_method='sum')),\
defaultdict(partial(self.register_element, aggregate_method='sum'))
def get_metric(self) -> dict:
r"""
Compute the final evaluation result from the statistics accumulated by update().
:return dict evaluate_result: {"f": float, "pre": float, "rec": float}
"""
evaluate_result = {}
if not self.only_gross or self.f_type == 'macro':
tags = set(self._fn.keys())
tags.update(set(self._fp.keys()))
tags.update(set(self._tp.keys()))
f_sum = 0
pre_sum = 0
rec_sum = 0
for tag in tags:
if self.tag_vocab is not None:
tag_name = self.tag_vocab.to_word(tag)
else:
tag_name = int(tag)
tp = self._tp[tag].get_scalar()
fn = self._fn[tag].get_scalar()
fp = self._fp[tag].get_scalar()
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
f_sum += f
pre_sum += pre
rec_sum += rec
if not self.only_gross and tag != '':  # tag != '' guards against the empty-tag case
f_key = 'f-{}'.format(tag_name)
pre_key = 'pre-{}'.format(tag_name)
rec_key = 'rec-{}'.format(tag_name)
evaluate_result[f_key] = f
evaluate_result[pre_key] = pre
evaluate_result[rec_key] = rec
if self.f_type == 'macro':
evaluate_result['f'] = f_sum / len(tags)
evaluate_result['pre'] = pre_sum / len(tags)
evaluate_result['rec'] = rec_sum / len(tags)
if self.f_type == 'micro':
f, pre, rec = _compute_f_pre_rec(self.beta_square,
sum(val.get_scalar() for val in self._tp.values()),
sum(val.get_scalar() for val in self._fn.values()),
sum(val.get_scalar() for val in self._fp.values()))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec
for key, value in evaluate_result.items():
evaluate_result[key] = round(value, 6)
return evaluate_result
def update(self, pred, target, seq_len=None):
pred = self.tensor2numpy(pred)
target = self.tensor2numpy(target)
if seq_len is not None:
seq_len = self.tensor2numpy(seq_len)
if seq_len is not None and target.ndim > 1:
max_len = target.shape[-1]
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
else:
masks = None
if pred.ndim == target.ndim:
if len(pred.flatten()) != len(target.flatten()):
raise RuntimeError(f"when pred and target have the same number of dimensions, they should have the same"
f" number of elements, while pred has {len(pred.flatten())} elements and "
f"target has {len(target.flatten())} elements.")
elif pred.ndim == target.ndim + 1:
pred = pred.argmax(axis=-1)
if seq_len is None and target.ndim > 1:
warnings.warn("You are not passing `seq_len` to exclude pad tokens when calculating accuracy.")
else:
raise RuntimeError(f"when pred have "
f"size:{pred.ndim}, target should have size: {pred.ndim} or "
f"{pred.ndim[:-1]}, got {target.ndim}.")
if masks is not None:
target = target * masks
pred = pred * masks
target_idxes = set(target.reshape(-1).tolist())
for target_idx in target_idxes:
self._tp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item()
self._fp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item()
self._fn[target_idx] += ((pred != target_idx) * (target == target_idx)).sum().item()
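A worked check of _compute_f_pre_rec (illustrative numbers only): with beta=1, tp=8, fp=2 and fn=4, precision is 8/10 = 0.8, recall is 8/12 ≈ 0.667, and F1 ≈ 0.727.

f, pre, rec = _compute_f_pre_rec(beta_square=1, tp=8, fn=4, fp=2)
print(round(pre, 3), round(rec, 3), round(f, 3))  # 0.8 0.667 0.727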

View File

@@ -0,0 +1,281 @@
__all__ = [
'Element'
]
import os
from .backend import Backend, AutoBackend
from fastNLP.core.log import logger
from .utils import AggregateMethodError
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
class Element:
def __init__(self, value: float, aggregate_method, backend: Backend, name=None):
self.init_value = value
self.aggregate_method = aggregate_method
self.name = name
if backend == 'auto':
raise RuntimeError("You have to specify the backend.")
elif isinstance(backend, AutoBackend):
self.backend = backend
else:
self.backend = AutoBackend(backend)
if self.backend.is_specified():
value = self.backend.create_tensor(self.init_value)
else:
value = None
self._value = value
self.device = None
def aggregate(self):
"""
Automatically aggregate this element across processes.
"""
try:
self._value = self.backend.aggregate(self._value, self.aggregate_method)
except AggregateMethodError as e:
msg = 'If you see this message, please report a bug.'
if self.name and e.should_have_aggregate_method:
msg = f"Element:{self.name} has no specified `aggregate_method`."
elif e.should_have_aggregate_method:
msg = "Element has no specified `aggregate_method`."
elif self.name and not e.should_have_aggregate_method:
msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \
f'aggregate_method:{self.aggregate_method}.'
elif not e.should_have_aggregate_method:
msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \
f'aggregate_method:{self.aggregate_method}.'
if e.only_warn:
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
logger.warning(msg)
self._value = self.backend.aggregate(self._value, method=None)
else:
raise RuntimeError(msg)
def reset(self):
if self.backend.is_specified():
self._value = self.backend.fill_value(self._value, self.init_value)
@property
def value(self):
return self._value
@value.setter
def value(self, value):
self._check_value_initialized()
self._value = value
@value.getter
def value(self):
self._check_value_initialized()
return self._value
def get_scalar(self) -> float:
return self.backend.get_scalar(self._value)
def fill_value(self, value):
self._value = self.backend.fill_value(self._value, value)
def to(self, device):
# how should device best be handled here?
if self._value is not None:
self._value = self.backend.move_tensor_to_device(self._value, device)
self.device = device
def _check_value_initialized(self):
if self._value is None:
assert self.backend.is_specified(), f"Backend is not specified, please specify backend in the Metric " \
f"initialization."
self._value = self.backend.create_tensor(self.init_value)
if self.device is not None:
self.to(device=self.device)
def _check_value_when_call(self):
if self.value is None:
prefix = f'Element:`{self.name}`' if self.name else 'Element'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")
def __add__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value += other.value
else:
self.value += other
return self
def __radd__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value += other.value
else:
self.value += other
return self
def __sub__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value -= other.value
else:
self.value -= other
return self
def __rsub__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value -= other.value
else:
self.value -= other
return self
def __mul__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value *= other.value
else:
self.value *= other
return self
def __imul__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value *= other.value
else:
self.value *= other
return self
def __floordiv__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value //= other.value
else:
self.value //= other
return self
def __rfloordiv__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value //= other.value
else:
self.value //= other
return self
def __truediv__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value /= other.value
else:
self.value /= other
return self
def __rtruediv__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value /= other.value
else:
self.value /= other
return self
def __mod__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value %= other.value
else:
self.value %= other
return self
def __rmod__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value %= other.value
else:
self.value %= other
return self
def __pow__(self, other, modulo=None):
self._check_value_when_call()
if modulo is None:
if isinstance(other, Element):
self.value **= other.value
else:
self.value **= other
else:
if isinstance(other, Element):
self.value = pow(self.value, other.value, modulo)
else:
self.value = pow(self.value, other, modulo)
return self
def __rpow__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
self.value **= other.value
else:
self.value **= other
return self
def __lt__(self, other) -> bool:
self._check_value_when_call()
if isinstance(other, Element):
return self.value < other.value
else:
return self.value < other
def __le__(self, other) -> bool:
self._check_value_when_call()
if isinstance(other, Element):
return self.value <= other.value
else:
return self.value <= other
def __eq__(self, other):
self._check_value_when_call()
if isinstance(other, Element):
return self.value == other.value
else:
return self.value == other
def __ne__(self, other) -> bool:
self._check_value_when_call()
if isinstance(other, Element):
return self.value != other.value
else:
return self.value != other
def __ge__(self, other) -> bool:
self._check_value_when_call()
if isinstance(other, Element):
return self.value >= other.value
else:
return self.value >= other
def __gt__(self, other) -> bool:
self._check_value_when_call()
if isinstance(other, Element):
return self.value > other.value
else:
return self.value > other
def __str__(self):
return str(self.value)
def __repr__(self):
return str(self.value)
def __getattr__(self, item):
"""
Forward attribute access to the underlying value so that an Element behaves like its backend tensor.
:param item:
:return:
"""
try:
if self._value is None:
prefix = f'Element:`{self.name}`' if self.name else 'Element'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")
return getattr(self._value, item)
except AttributeError as e:
raise e
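A rough sketch of how an Element behaves through the operators above (hypothetical stand-alone use, assuming torch so the backend can materialize a tensor; inside fastNLP, elements are created via Metric.register_element):

correct = Element(value=0, aggregate_method='sum', backend='torch', name='correct')
correct += 3                 # in-place update, like a tensor
correct += 2
print(correct.get_scalar())  # 5.0
correct.reset()              # back to the initial value
print(correct.get_scalar())  # 0.0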

View File

@@ -0,0 +1,184 @@
__all__ = [
'Metric'
]
from abc import abstractmethod
from typing import Union
import functools
from contextlib import contextmanager
import numpy as np
from fastNLP.core.metrics.backend import Backend, AutoBackend
from fastNLP.core.metrics.element import Element
class Metric:
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
"""
:param str backend: one of four supported backends: [torch, paddle, jittor, auto]. 'auto' picks the concrete
backend from the arguments actually passed to Metric.update(); in most cases 'auto' is all you need.
:param bool aggregate_when_get_metric: whether to aggregate matching elements across processes before computing
the metric; meaningless when the backend does not support distributed execution.
"""
self.backend = AutoBackend(backend)
self._updated = False
self.get_metric = self._sync_get_metric(self.get_metric)
self.update = self._wrap_update(self.update)
self.reset = self._wrap_auto_reset_elements(self.reset)
self.aggregate_when_get_metric = aggregate_when_get_metric
self._cannot_change_element = False
self._elements = {}
@property
def elements(self) -> dict:
return self._elements
def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element:
"""
Register an element. Once registered, it can be accessed inside the Metric as self.{name} and treated as a
tensor of the corresponding backend, i.e. used directly in arithmetic.
Note: if the metric should scale to multi-card runs automatically, be sure to declare aggregate_method.
:param name: name of this element; after registration it is available as self.{name} inside the Metric
:param value: the initial value; Metric.reset() also resets the element to this value
:param aggregate_method: how to aggregate results across cards; meaningless in single-card runs
:param backend: the backend to use; the Element is initialized accordingly: a torch backend yields a
torch.Tensor, a paddle backend a paddle.Tensor, and a jittor backend a jittor.Var. Usually the default
'auto' is fine: fastNLP infers the backend from the arguments actually passed to Metric.update(). For
example, if the only tensor type among them is torch.Tensor (non-tensor inputs are allowed), the backend
is torch; if only jittor.Var, it is jittor. If no tensor type is detected at all, a plain float is used
as the element.
:return: the registered Element object
"""
if backend == 'auto':
backend = self.backend
else:
backend = AutoBackend(backend)
# when name is None, generate a default variable name
if name is None:
name = f'ele_var_{len(self._elements)}'
element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name)
self.elements[name] = element
setattr(self, name, element)
return element
def reset(self):
"""
If any non-element state needs resetting, put that logic here; registered elements are automatically reset to their initial values.
"""
pass
def _wrap_auto_reset_elements(self, reset):
@functools.wraps(reset)
def _wrap_reset(*args, **kwargs):
self._updated = False
for ele in self.elements.values():
ele.reset()
reset(*args, **kwargs)
return _wrap_reset
def _sync_get_metric(self, get_metric):
@functools.wraps(get_metric)
def _wrap_get_metric(*args, **kwargs):
assert self._updated, f"You have to call `{self.__class__.__name__}.update()` before calling " \
f"`get_metric()`."
with self.sync(recover=True, aggregate=self.aggregate_when_get_metric):
results = get_metric(*args, **kwargs)
return results
return _wrap_get_metric
def __setattr__(self, key, value):
if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True:
if key in self.elements and value is not self.elements[key]:
raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}")
object.__setattr__(self, key, value)
def _wrap_update(self, update):
@functools.wraps(update)
def _wrap_update(*args, **kwargs):
self.check_backend(*args, **kwargs)
self._cannot_change_element = True
self._updated = True
return update(*args, **kwargs)
return _wrap_update
def check_backend(self, *args, **kwargs):
if not self.backend.is_specified():
_args = []
for arg in args:
_args.append(arg)
for arg in kwargs.values():
_args.append(arg)
self.backend.choose_real_backend(_args)
@contextmanager
def sync(self, recover=True, aggregate=False):
"""
Within this context the metric first synchronizes the elements that need it; if recover is True, the
element values are restored to their pre-aggregation values on exit.
"""
keep_value = {}
if aggregate:
for name, element in self.elements.items():
# keep the previous value
keep_value[name] = element.get_scalar()
# aggregate across processes
element.aggregate()
yield
if recover and aggregate:
for name, element in self.elements.items():
# restore the previous value
if name in keep_value:
element.fill_value(value=keep_value.get(name))
@abstractmethod
def update(self, *args, **kwargs):
raise NotImplementedError()
@abstractmethod
def get_metric(self) -> dict:
raise NotImplementedError()
def set_auto_aggregate_when_get_metric(self, flag: bool):
"""
Set whether elements are aggregated automatically when get_metric is called.
"""
self.aggregate_when_get_metric = flag
def __getattr__(self, name: str) -> Element:
if 'elements' in self.__dict__:
elements = self.__dict__['elements']
if name in elements:
return elements[name]
raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name))
def tensor2numpy(self, tensor) -> np.array:
"""
Convert a tensor to a numpy variable.
:param tensor:
:return:
"""
return self.backend.tensor2numpy(tensor)
def to(self, device):
"""
Move all element variables to device.
:param device:
:return:
"""
for element in self.elements.values():
element.to(device)
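Putting the pieces together, a hedged sketch of a custom metric built on register_element (a hypothetical counting metric, not part of the commit):

class CountMetric(Metric):
    def __init__(self, backend='auto'):
        super().__init__(backend=backend)
        self.register_element(name='count', value=0, aggregate_method='sum')

    def update(self, pred, target):
        self.count += len(pred)  # registered elements support in-place arithmetic

    def get_metric(self) -> dict:
        return {'count': self.count.get_scalar()}

metric = CountMetric()
metric.update([1, 2, 3], [0, 0, 0])  # plain lists -> falls back to the float backend
print(metric.get_metric())           # {'count': 3}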

View File

@@ -0,0 +1,344 @@
__all__ = [
'SpanFPreRecMetric'
]
from typing import Union, List, Optional
import warnings
from collections import defaultdict
from functools import partial
from fastNLP.core.metrics.backend import Backend
from fastNLP.core.metrics.metric import Metric
from fastNLP.core.vocabulary import Vocabulary
def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str):
r"""
Check that the tags in the vocab match the given encoding_type.
:param tag_vocab: a tag Vocabulary, or a dict like {0: "O", 1: "B-tag1"} mapping index to tag
:param encoding_type: bio, bmes, bioes or bmeso
:return:
"""
tag_set = set()
unk_token = '<unk>'
pad_token = '<pad>'
if isinstance(tag_vocab, Vocabulary):
unk_token = tag_vocab.unknown
pad_token = tag_vocab.padding
tag_vocab = tag_vocab.idx2word
for idx, tag in tag_vocab.items():
if tag in (unk_token, pad_token):
continue
tag = tag[:1].lower()
tag_set.add(tag)
tags = encoding_type
for tag in tag_set:
assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \
f"encoding_type."
tags = tags.replace(tag, '')  # remove the matched tag
if tags:  # anything left over means some tags were never used
warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
"encoding_type.")
def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str:
r"""
Given a Vocabulary, infer which encoding type it uses; supports bmes, bioes, bmeso and bio.
:param tag_vocab: a tag Vocabulary, or a dict like {0: "O", 1: "B-tag1"} mapping index to tag
:return:
"""
tag_set = set()
unk_token = '<unk>'
pad_token = '<pad>'
if isinstance(tag_vocab, Vocabulary):
unk_token = tag_vocab.unknown
pad_token = tag_vocab.padding
tag_vocab = tag_vocab.idx2word
for idx, tag in tag_vocab.items():
if tag in (unk_token, pad_token):
continue
tag = tag[:1].lower()
tag_set.add(tag)
bmes_tag_set = set('bmes')
if tag_set == bmes_tag_set:
return 'bmes'
bio_tag_set = set('bio')
if tag_set == bio_tag_set:
return 'bio'
bmeso_tag_set = set('bmeso')
if tag_set == bmeso_tag_set:
return 'bmeso'
bioes_tag_set = set('bioes')
if tag_set == bioes_tag_set:
return 'bioes'
raise RuntimeError("encoding_type cannot be inferred automatically. Only support "
"'bio', 'bmes', 'bmeso', 'bioes' type.")
def _bmes_tag_to_spans(tags, ignore_labels=None):
r"""
Given a list of tags such as ['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-movie', 'S-actor'],
return [('song', (0, 1)), ('singer', (1, 4)), ('movie', (4, 5)), ('actor', (5, 6))] (half-open intervals).
The input may also be a plain ['S', 'B', 'M', 'E', 'B', 'M', 'M', ...] sequence.
:param tags: List[str]
:param ignore_labels: List[str], labels in this list are ignored
:return: List[Tuple[str, Tuple[int, int]]], i.e. [(label, (start, end))]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()
spans = []
prev_bmes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
spans[-1][1][1] = idx
else:
spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
]
def _bmeso_tag_to_spans(tags, ignore_labels=None):
r"""
Given a list of tags such as ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O'],
return [('singer', (1, 4))] (half-open intervals).
:param tags: List[str]
:param ignore_labels: List[str], labels in this list are ignored
:return: List[Tuple[str, Tuple[int, int]]], i.e. [(label, (start, end))]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()
spans = []
prev_bmes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
spans[-1][1][1] = idx
elif bmes_tag == 'o':
pass
else:
spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
]
def _bioes_tag_to_spans(tags, ignore_labels=None):
r"""
Given a list of tags such as ['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O'],
return [('singer', (1, 4))] (half-open intervals).
:param tags: List[str]
:param ignore_labels: List[str], labels in this list are ignored
:return: List[Tuple[str, Tuple[int, int]]], i.e. [(label, (start, end))]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()
spans = []
prev_bioes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bioes_tag, label = tag[:1], tag[2:]
if bioes_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
spans[-1][1][1] = idx
elif bioes_tag == 'o':
pass
else:
spans.append((label, [idx, idx]))
prev_bioes_tag = bioes_tag
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
]
def _bio_tag_to_spans(tags, ignore_labels=None):
r"""
Given a list of tags such as ['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O'],
return [('singer', (1, 4))] (half-open intervals).
:param tags: List[str]
:param ignore_labels: List[str], labels in this list are ignored
:return: List[Tuple[str, Tuple[int, int]]], i.e. [(label, (start, end))]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()
spans = []
prev_bio_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bio_tag, label = tag[:1], tag[2:]
if bio_tag == 'b':
spans.append((label, [idx, idx]))
elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]:
spans[-1][1][1] = idx
elif bio_tag == 'o': # o tag does not count
pass
else:
spans.append((label, [idx, idx]))
prev_bio_tag = bio_tag
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]
def _compute_f_pre_rec(beta_square, tp, fn, fp):
r"""
:param beta_square: the square of beta for the F-beta score
:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
return f, pre, rec
class SpanFPreRecMetric(Metric):
def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None,
encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro',
beta=1, aggregate_when_get_metric: bool = True,) -> None:
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
if f_type not in ('micro', 'macro'):
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
if not isinstance(tag_vocab, Vocabulary):
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
if encoding_type:
encoding_type = encoding_type.lower()
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
self.encoding_type = encoding_type
else:
self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab)
if self.encoding_type == 'bmes':
self.tag_to_span_func = _bmes_tag_to_spans
elif self.encoding_type == 'bio':
self.tag_to_span_func = _bio_tag_to_spans
elif self.encoding_type == 'bmeso':
self.tag_to_span_func = _bmeso_tag_to_spans
elif self.encoding_type == 'bioes':
self.tag_to_span_func = _bioes_tag_to_spans
else:
raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.")
self.ignore_labels = ignore_labels
self.f_type = f_type
self.beta = beta
self.beta_square = self.beta ** 2
self.only_gross = only_gross
self.tag_vocab = tag_vocab
self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
def get_metric(self) -> dict:
evaluate_result = {}
if not self.only_gross or self.f_type == 'macro':
tags = set(self._false_negatives.keys())
tags.update(set(self._false_positives.keys()))
tags.update(set(self._true_positives.keys()))
f_sum = 0
pre_sum = 0
rec_sum = 0
for tag in tags:
tp = self._true_positives[tag].get_scalar()
fn = self._false_negatives[tag].get_scalar()
fp = self._false_positives[tag].get_scalar()
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
f_sum += f
pre_sum += pre
rec_sum += rec
if not self.only_gross and tag != '':  # tag != '' guards against the empty-tag case
f_key = 'f-{}'.format(tag)
pre_key = 'pre-{}'.format(tag)
rec_key = 'rec-{}'.format(tag)
evaluate_result[f_key] = f
evaluate_result[pre_key] = pre
evaluate_result[rec_key] = rec
if self.f_type == 'macro':
evaluate_result['f'] = f_sum / len(tags)
evaluate_result['pre'] = pre_sum / len(tags)
evaluate_result['rec'] = rec_sum / len(tags)
if self.f_type == 'micro':
f, pre, rec = _compute_f_pre_rec(self.beta_square,
sum(val.get_scalar() for val in self._true_positives.values()),
sum(val.get_scalar() for val in self._false_negatives.values()),
sum(val.get_scalar() for val in self._false_positives.values()))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec
for key, value in evaluate_result.items():
evaluate_result[key] = round(value, 6)
return evaluate_result
def update(self, pred, target, seq_len: Optional[List] = None) -> None:
r"""update函数将针对一个批次的预测结果做评价指标的累计
:param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果
:param target: [batch, seq_len], 真实值
:param seq_len: [batch] 文本长度标记
:return:
"""
pred = self.tensor2numpy(pred)
target = self.tensor2numpy(target)
if pred.ndim == target.ndim and target.ndim == 2:
pass
elif pred.ndim == target.ndim + 1 and target.ndim == 2:
num_classes = pred.shape[-1]
pred = pred.argmax(axis=-1)
if (target >= num_classes).any():
raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
"id >= {}, the number of classes.".format(num_classes))
else:
raise RuntimeError(f"when pred have size:{pred.ndim}, target should have size: {pred.ndim} or "
f"{pred.shape[:-1]}, got {target.ndim}.")
batch_size = pred.shape[0]
pred = pred.tolist()
target = target.tolist()
for i in range(batch_size):
pred_tags = pred[i][:int(seq_len[i])]
gold_tags = target[i][:int(seq_len[i])]
pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]
pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)
for span in pred_spans:
if span in gold_spans:
self._true_positives[span[0]] += 1
gold_spans.remove(span)
else:
self._false_positives[span[0]] += 1
for span in gold_spans:
self._false_negatives[span[0]] += 1
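For reference, the tag-to-span helpers produce half-open spans; a quick check with _bio_tag_to_spans:

tags = ['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'B-song']
print(_bio_tag_to_spans(tags))
# [('singer', (1, 4)), ('song', (5, 6))]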

View File

@@ -0,0 +1,91 @@
__all__ = [
'func_post_proc'
]
from typing import Any
from functools import wraps
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.utils import _module_available
_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics')
if _IS_TORCHMETRICS_AVAILABLE:
from torchmetrics import Metric as torchmetrics_Metric
_IS_ALLENNLP_AVAILABLE = _module_available('allennlp')
if _IS_ALLENNLP_AVAILABLE:
from allennlp.training.metrics import Metric as allennlp_Metric
if _NEED_IMPORT_PADDLE:
from paddle.metric import Metric as paddle_Metric
def _is_torchmetrics_metric(metric: Any) -> bool:
"""
Check whether the input object is a torchmetrics metric.
:param metric:
:return:
"""
if _IS_TORCHMETRICS_AVAILABLE:
return isinstance(metric, torchmetrics_Metric)
else:
return False
def _is_allennlp_metric(metric: Any) -> bool:
"""
Check whether the input object is an allennlp metric.
:param metric:
:return:
"""
if _IS_ALLENNLP_AVAILABLE:
return isinstance(metric, allennlp_Metric)
else:
return False
def _is_paddle_metric(metric: Any) -> bool:
"""
Check whether the input object is a paddle metric.
:param metric:
:return:
"""
if _NEED_IMPORT_PADDLE:
return isinstance(metric, paddle_Metric)
else:
return False
def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric':
"""
Wrap fn around the {method_name} method of metric so that the return value of metric.{method_name}
is passed through fn before being returned. Note that the change to metric's {method_name} is made in place.
:param metric: the metric object
:param fn: applied to the return value of metric.{method_name}
:param method_name: name of the method to wrap; which method is appropriate depends on the metric class
:return: metric
"""
assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \
f"Parameter `metric` must have a {method_name} function."
assert callable(fn), "Parameter `fn` must be callable."
func = getattr(metric, method_name)
@wraps(func)
def wrap_method(*args, **kwargs):
res = func(*args, **kwargs)
return fn(res)
wrap_method.__wrapped_by_func_post_proc__ = True
setattr(metric, method_name, wrap_method)
return metric
class AggregateMethodError(BaseException):
def __init__(self, should_have_aggregate_method, only_warn=False):
super(AggregateMethodError, self).__init__(self)
self.should_have_aggregate_method = should_have_aggregate_method
self.only_warn = only_warn
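A minimal sketch of func_post_proc on a hypothetical metric-like object, rounding whatever its get_metric returns:

class DummyMetric:
    def get_metric(self):
        return {'acc': 0.123456789}

metric = func_post_proc(DummyMetric(), lambda res: {k: round(v, 3) for k, v in res.items()},
                        method_name='get_metric')
print(metric.get_metric())  # {'acc': 0.123}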