mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
添加了metrics
This commit is contained in:
parent
ce1b837d13
commit
4bb56616b9
18
fastNLP/core/metrics/__init__.py
Normal file
18
fastNLP/core/metrics/__init__.py
Normal file
@ -0,0 +1,18 @@
|
||||
__all__ = [
|
||||
"Metric",
|
||||
"Accuracy",
|
||||
'Backend',
|
||||
'AutoBackend',
|
||||
'PaddleBackend',
|
||||
'TorchBackend',
|
||||
'SpanFPreRecMetric',
|
||||
'ClassifyFPreRecMetric',
|
||||
'func_post_proc'
|
||||
]
|
||||
|
||||
from .metric import Metric
|
||||
from .accuracy import Accuracy
|
||||
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend
|
||||
from .span_f1_pre_rec_metric import SpanFPreRecMetric
|
||||
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
|
||||
from .utils import func_post_proc
|
75
fastNLP/core/metrics/accuracy.py
Normal file
75
fastNLP/core/metrics/accuracy.py
Normal file
@ -0,0 +1,75 @@
|
||||
__all__ = [
|
||||
'Accuracy'
|
||||
]
|
||||
|
||||
from typing import Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.metrics.metric import Metric
|
||||
from fastNLP.core.metrics.backend import Backend
|
||||
from fastNLP.core.utils.utils import seq_len_to_mask
|
||||
|
||||
|
||||
class Accuracy(Metric):
    """
    Accuracy metric: accumulates the number of correct predictions
    (``correct``) and the number of evaluated items (``total``) and reports
    ``correct / total``.

    :param backend: backend used to store/aggregate the statistics; ``'auto'``
        selects one from the tensor types seen at update time.
    :param aggregate_when_get_metric: whether to aggregate the statistics
        across distributed workers when :meth:`get_metric` is called.
    """

    def __init__(self, backend: Union[str, Backend, None] = 'auto',
                 aggregate_when_get_metric: bool = True):
        super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
        # Both counters are summed across workers on aggregation.
        self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend)
        self.register_element(name='total', value=0, aggregate_method="sum", backend=backend)

    def get_metric(self) -> dict:
        r"""
        Compute the final accuracy from the statistics accumulated by
        :meth:`update`.

        :return dict evaluate_result: {"acc": float}
        """
        # 1e-12 guards against division by zero when update() was never called.
        evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)}
        return evaluate_result

    def update(self, pred, target, seq_len=None):
        r"""
        Accumulate accuracy statistics for one batch of predictions.

        :param pred: prediction tensor; shape may be [B,], [B, n_classes],
            [B, max_len] or [B, max_len, n_classes]
        :param target: gold tensor; shape may be [B,] or [B, max_len]
        :param seq_len: optional sequence-length tensor of shape [B]; used to
            mask out padding positions when ``target`` has a length dimension.
        """
        # Convert everything to numpy so the computation is framework-agnostic.
        pred = self.tensor2numpy(pred)
        target = self.tensor2numpy(target)
        if seq_len is not None:
            seq_len = self.tensor2numpy(seq_len)

        if seq_len is not None and target.ndim > 1:
            max_len = target.shape[1]
            masks = seq_len_to_mask(seq_len, max_len)
        else:
            masks = None

        if pred.ndim == target.ndim:
            if np.prod(pred.shape) != np.prod(target.shape):
                # BUGFIX: the second interpolation previously printed
                # target.shape for "pred have shape" as well.
                raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers."
                                   f" while target have shape:{target.shape}, "
                                   f"pred have shape: {pred.shape}")

        elif pred.ndim == target.ndim + 1:
            # The trailing dimension holds class scores; reduce it to hard
            # predictions.
            pred = pred.argmax(axis=-1)
            if seq_len is None and target.ndim > 1:
                warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")

        else:
            raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or "
                               f"{pred.shape[:-1]}, got {target.shape}.")

        if masks is not None:
            # Only count unmasked (non-padding) positions.
            self.total += masks.sum().item()
            self.correct += ((pred == target) * masks).sum().item()
        else:
            self.total += np.prod(list(pred.shape)).item()
            self.correct += (target == pred).sum().item()
12
fastNLP/core/metrics/backend/__init__.py
Normal file
12
fastNLP/core/metrics/backend/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
__all__ = [
|
||||
'Backend',
|
||||
'AutoBackend',
|
||||
'TorchBackend',
|
||||
'PaddleBackend'
|
||||
]
|
||||
|
||||
|
||||
from .backend import Backend
|
||||
from .auto_backend import AutoBackend
|
||||
from .torch_backend.backend import TorchBackend
|
||||
from .paddle_backend.backend import PaddleBackend
|
75
fastNLP/core/metrics/backend/auto_backend.py
Normal file
75
fastNLP/core/metrics/backend/auto_backend.py
Normal file
@ -0,0 +1,75 @@
|
||||
from typing import Union
|
||||
|
||||
from .backend import Backend
|
||||
from .torch_backend.backend import TorchBackend
|
||||
from .paddle_backend.backend import PaddleBackend
|
||||
from .jittor_backend.backend import JittorBackend
|
||||
|
||||
|
||||
class AutoBackend(Backend):
    """
    A backend that needs no up-front initialisation: constructed with
    ``'auto'`` it stays undecided, and later picks a concrete backend from the
    tensor types observed when the metric is used.
    """

    def __init__(self, backend: Union[str, Backend, None]):
        super(AutoBackend, self).__init__()
        # 'auto' means: stay undecided until choose_real_backend() runs.
        if backend != 'auto':
            self._convert_backend(backend)

    def _convert_backend(self, backend):
        """
        Turn this AutoBackend instance into the concrete Backend described by
        ``backend`` (a Backend instance, a framework name, or None).
        """
        if isinstance(backend, Backend):
            self.__class__ = backend.__class__
        # A plain string selects the matching framework backend directly.
        elif backend == 'torch':
            self.__class__ = TorchBackend
        elif backend == 'paddle':
            self.__class__ = PaddleBackend
        elif backend == 'jittor':
            self.__class__ = JittorBackend
        elif backend is None:
            # Nothing to do: keep the plain (no-op) Backend behaviour.
            pass
        else:
            raise RuntimeError(f"We did not support `{backend}` to be used as backend for now.")
        self._specified = True

    def choose_real_backend(self, args):
        """
        Inspect the runtime types of ``args`` and switch this instance to the
        matching concrete backend.
        """
        assert not self.is_specified(), "This method should not be called after backend has been specified. " \
                                        "This must be a bug, please report."
        type_names = [str(type(arg)) for arg in args]

        torch_types = [name for name in type_names if 'torch' in name]
        paddle_types = [name for name in type_names if 'paddle' in name]
        jittor_types = [name for name in type_names if 'jittor' in name]

        # Per https://stackoverflow.com/a/3464154 , reassigning __class__ is a
        # workable way to switch the instance to the real backend.
        if torch_types and not jittor_types and not paddle_types:
            backend = 'torch'
        elif jittor_types and not torch_types and not paddle_types:
            backend = 'jittor'
        elif paddle_types and not torch_types and not jittor_types:
            backend = 'paddle'
        elif not (torch_types or jittor_types or paddle_types):
            # No framework tensors seen: fall back to the default backend.
            backend = None
        else:
            types = list(set(torch_types + jittor_types + paddle_types))
            raise RuntimeError(
                f"Mixture of tensor type:{types} have been accept, please manually set backend instead of "
                f"using backend=auto.")

        self._convert_backend(backend)
75
fastNLP/core/metrics/backend/backend.py
Normal file
75
fastNLP/core/metrics/backend/backend.py
Normal file
@ -0,0 +1,75 @@
|
||||
from ..utils import AggregateMethodError
|
||||
|
||||
|
||||
class Backend:
    """
    Base class for metric backends. Every method of Backend and of its
    subclasses must be stateless; the default implementations below are
    no-ops suitable for plain Python numbers.
    """

    def __init__(self):
        # Subclasses that represent a real framework set this to True.
        self._specified = False

    def aggregate(self, tensor, method: str):
        """
        Aggregate ``tensor`` across workers according to ``method`` and return
        the result. The base backend cannot aggregate, so requesting any
        aggregation method is reported as an error.
        """
        if method is not None:
            # BUGFIX: this error was previously *returned*, so callers that
            # catch AggregateMethodError (see Element.aggregate) never saw it
            # and silently stored the exception object as the value.
            raise AggregateMethodError(should_have_aggregate_method=False, only_warn=True)

        return tensor

    def create_tensor(self, value: float):
        """
        Create a tensor holding ``value``. The base backend just keeps the
        plain value.
        """
        return value

    def fill_value(self, tensor, value: float):
        """
        Set the content of ``tensor`` to ``value`` and return it.
        """
        return value

    def get_scalar(self, tensor) -> float:
        """
        Return the scalar value of ``tensor``.

        :param tensor:
        :return:
        """
        return tensor

    def is_specified(self) -> bool:
        """
        Whether this backend is bound to a concrete framework.

        :return:
        """
        return self._specified

    def tensor2numpy(self, tensor):
        """
        Convert ``tensor`` to a numpy value. The base backend returns the
        input unchanged.

        :param tensor:
        :return:
        """
        return tensor

    def move_tensor_to_device(self, tensor, device):
        """
        Move ``tensor`` to ``device``; a no-op for the base backend.
        """
        return tensor

    def all_gather_object(self, obj, group=None):
        """
        Gather ``obj`` from every rank and return a list containing each
        rank's object in rank order. Must be implemented by framework-specific
        subclasses.

        :param obj:
        :param group:
        :return:
        """
        raise NotImplementedError(f"all_gather_object() function is not implemented for {self.__class__.__name__}.")
1
fastNLP/core/metrics/backend/jittor_backend/__init__.py
Normal file
1
fastNLP/core/metrics/backend/jittor_backend/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
72
fastNLP/core/metrics/backend/jittor_backend/backend.py
Normal file
72
fastNLP/core/metrics/backend/jittor_backend/backend.py
Normal file
@ -0,0 +1,72 @@
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
from fastNLP.core.metrics.backend import Backend
|
||||
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor
|
||||
|
||||
|
||||
class JittorBackend(Backend):
    """Backend implementation for the jittor framework."""

    def __init__(self):
        super(JittorBackend, self).__init__()
        self._specified = True

    def aggregate(self, tensor, method: str):
        """
        Aggregate ``tensor`` according to ``method``. Distributed aggregation
        is not implemented for jittor, so the tensor is returned unchanged.
        """
        return tensor

    def create_tensor(self, value: float):
        """
        Create a jittor tensor initialised with ``value``.
        """
        value = jittor.Var(value)
        return value

    def fill_value(self, tensor, value: float):
        """
        Return a tensor shaped like ``tensor`` whose entries are all ``value``.
        """
        value = jittor.full_like(tensor, value)
        return value

    def get_scalar(self, tensor) -> float:
        """
        Scalar value of ``tensor``.

        :param tensor:
        :return:
        """
        return tensor.item()

    def is_specified(self) -> bool:
        """
        Whether this backend is bound to a concrete framework.

        :return:
        """
        return self._specified

    def tensor2numpy(self, tensor):
        """
        Convert ``tensor`` to a numpy ndarray.

        :param tensor:
        :return:
        """
        if isinstance(tensor, jittor.Var):
            return tensor.detach().numpy()
        elif isinstance(tensor, np.ndarray):
            # BUGFIX: was `isinstance(tensor, np.array)`; np.array is a
            # factory function, so that check raised TypeError at runtime
            # instead of matching ndarrays.
            return tensor
        else:
            raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

    def move_tensor_to_device(self, tensor, device):
        """
        jittor has no explicit device-transfer API, so this is a no-op.
        """
        return tensor
5
fastNLP/core/metrics/backend/paddle_backend/__init__.py
Normal file
5
fastNLP/core/metrics/backend/paddle_backend/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
__all__ = [
|
||||
'PaddleBackend'
|
||||
]
|
||||
|
||||
from .backend import Backend as PaddleBackend
|
126
fastNLP/core/metrics/backend/paddle_backend/backend.py
Normal file
126
fastNLP/core/metrics/backend/paddle_backend/backend.py
Normal file
@ -0,0 +1,126 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.metrics.backend import Backend
|
||||
from fastNLP.core.utils.paddle_utils import paddle_to
|
||||
from fastNLP.core.metrics.utils import AggregateMethodError
|
||||
from fastNLP.core.utils import is_in_paddle_dist
|
||||
from fastNLP.core.drivers.paddle_driver.utils import get_device_from_visible
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
|
||||
|
||||
if _NEED_IMPORT_PADDLE:
|
||||
import paddle
|
||||
from paddle.fluid.dygraph import parallel_helper
|
||||
|
||||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
    """Gather ``result`` from every rank of ``group`` into a list of tensors
    (all ranks are assumed to hold tensors of identical shape)."""
    buckets = [paddle.zeros_like(result) for _ in range(world_size)]
    paddle.distributed.all_gather(buckets, result, group)
    return buckets
||||
|
||||
class PaddleBackend(Backend):
    """Backend implementation for the paddle framework."""

    def __init__(self):
        super().__init__()
        self._specified = True

    def aggregate(self, tensor, method: str):
        """
        Aggregate ``tensor`` across workers according to ``method`` and return
        the result.
        """
        if isinstance(tensor, paddle.Tensor):
            if parallel_helper._is_parallel_ctx_initialized():
                if method is None:
                    raise AggregateMethodError(should_have_aggregate_method=True)
                tensor = self._gather_all(tensor)
                if isinstance(tensor[0], paddle.Tensor):
                    tensor = paddle.stack(tensor)
                # Step 1: reduce the gathered tensors.
                # BUGFIX: paddle reductions take `axis=` (not torch's `dim=`),
                # and paddle.max/paddle.min return a tensor, not a
                # (values, indices) tuple — the tuple unpacking would fail.
                if method == 'sum':
                    tensor = paddle.sum(tensor, axis=0)
                elif method == 'mean':
                    tensor = paddle.mean(tensor, axis=0)
                elif method == 'max':
                    tensor = paddle.max(tensor, axis=0)
                elif method == 'min':
                    tensor = paddle.min(tensor, axis=0)
                else:
                    raise AggregateMethodError(should_have_aggregate_method=False)

        return tensor

    def create_tensor(self, value: float):
        """
        Create a one-element paddle tensor holding ``value``.
        """
        tensor = paddle.ones((1,)).fill_(value)
        return tensor

    def fill_value(self, tensor, value: float):
        """
        Set every entry of ``tensor`` to ``value`` and return it.
        """
        tensor.fill_(value)
        return tensor

    def get_scalar(self, tensor) -> float:
        """Return the Python scalar held by ``tensor``."""
        return tensor.item()

    def tensor2numpy(self, tensor) -> np.ndarray:
        """Convert ``tensor`` to a numpy ndarray."""
        if isinstance(tensor, paddle.Tensor):
            return tensor.cpu().detach().numpy()
        elif isinstance(tensor, np.ndarray):
            # BUGFIX: was `isinstance(tensor, np.array)`; np.array is a
            # function, so the check raised TypeError instead of matching.
            return tensor
        else:
            raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

    @staticmethod
    def _gather_all(result, group: Optional[Any] = None) -> List:
        """
        Gather ``result`` from every rank in ``group``; tensors whose shapes
        differ across ranks are padded before gathering and trimmed after.
        """
        # TODO check correctness
        if group is None:
            group = paddle.distributed.get_group(0)

        world_size = group.nranks
        paddle.distributed.barrier(group=group)

        # Scalar tensors need no padding; gather them directly.
        if result.ndim == 0:
            return _simple_gather_all_tensors(result, group, world_size)

        # Collect the shape of `result` on every rank.
        local_size = paddle.to_tensor(result.shape)
        local_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)]
        paddle.distributed.all_gather(local_sizes, local_size, group=group)
        # Per-dimension maximum over all ranks.
        # BUGFIX: paddle's Tensor.max(axis=...) returns a tensor directly;
        # `.values` is a torch API and does not exist on paddle tensors.
        max_size = paddle.stack(local_sizes).max(axis=0)
        all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

        # If every rank has the same shape a plain gather suffices.
        if all_sizes_equal:
            return _simple_gather_all_tensors(result, group, world_size)

        # Otherwise pad each local tensor up to the largest shape.
        pad_dims = []
        pad_by = (max_size - local_size).detach().cpu()
        for val in reversed(pad_by):
            pad_dims.append(0)
            pad_dims.append(val.item())
        result_padded = paddle.nn.functional.pad(result, pad_dims)
        # Gather the padded tensors, then trim each back to its true shape.
        gathered_result = [paddle.zeros_like(result_padded) for _ in range(world_size)]
        paddle.distributed.all_gather(gathered_result, result_padded, group)
        for idx, item_size in enumerate(local_sizes):
            slice_param = [slice(dim_size) for dim_size in item_size]
            gathered_result[idx] = gathered_result[idx][slice_param]
        return gathered_result

    def move_tensor_to_device(self, tensor, device):
        # TODO would remapping the device here cause bugs elsewhere?
        if is_in_paddle_dist():
            device = get_device_from_visible(device)
        return paddle_to(tensor, device)
6
fastNLP/core/metrics/backend/torch_backend/__init__.py
Normal file
6
fastNLP/core/metrics/backend/torch_backend/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
__all__ = [
|
||||
'TorchBackend'
|
||||
]
|
||||
|
||||
|
||||
from .backend import Backend as TorchBackend
|
154
fastNLP/core/metrics/backend/torch_backend/backend.py
Normal file
154
fastNLP/core/metrics/backend/torch_backend/backend.py
Normal file
@ -0,0 +1,154 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.metrics.backend import Backend
|
||||
from fastNLP.core.metrics.utils import AggregateMethodError
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather
|
||||
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
    """Gather ``result`` from every rank of ``group`` into a list of tensors
    (all ranks are assumed to hold tensors of identical shape)."""
    buckets = [torch.zeros_like(result) for _ in range(world_size)]
    dist.all_gather(buckets, result, group)
    return buckets
||||
|
||||
|
||||
class TorchBackend(Backend):
    """Backend implementation for the torch framework."""

    def __init__(self):
        super().__init__()
        self._specified = True

    def aggregate(self, tensor, method: str):
        """
        Aggregate ``tensor`` across workers according to ``method`` and return
        the result.
        """
        if isinstance(tensor, torch.Tensor):
            if dist.is_initialized():
                if method is None:
                    raise AggregateMethodError(should_have_aggregate_method=True)
                tensor = self._gather_all(tensor)
                if isinstance(tensor[0], torch.Tensor):
                    tensor = torch.stack(tensor)
                # Step 1: reduce the gathered tensors.
                if method == 'sum':
                    tensor = torch.sum(tensor, dim=0)
                elif method == 'mean':
                    tensor = torch.mean(tensor, dim=0)
                elif method == 'max':
                    tensor, _ = torch.max(tensor, dim=0)
                elif method == 'min':
                    tensor, _ = torch.min(tensor, dim=0)
                else:
                    raise AggregateMethodError(should_have_aggregate_method=False)

        return tensor

    def create_tensor(self, value: float):
        """
        Create a one-element torch tensor holding ``value``.
        """
        return torch.ones(1).fill_(value)

    def fill_value(self, tensor, value: float):
        """
        Set every entry of ``tensor`` to ``value`` and return it.
        """
        tensor.fill_(value)
        return tensor

    def get_scalar(self, tensor) -> float:
        """Return the Python scalar held by ``tensor``."""
        return tensor.item()

    @staticmethod
    def _gather_all(result, group: Optional[Any] = None) -> List:
        """Gather tensors from several ddp processes onto a list broadcast to all.

        Works on tensors that have the same number of dimensions, but where
        each dimension may differ: tensors are padded, gathered and then
        trimmed so every process gets the true per-rank data.

        Args:
            result: the value to sync
            group: the process group to gather results from. Defaults to all processes (world)

        Return:
            gathered_result: list with size equal to the process group where
                gathered_result[i] corresponds to result tensor from process i
        """
        if group is None:
            group = dist.group.WORLD

        # all_gather requires contiguous tensors
        result = result.contiguous()

        world_size = dist.get_world_size(group)
        dist.barrier(group=group)

        # Scalars need no padding at all.
        if result.ndim == 0:
            return _simple_gather_all_tensors(result, group, world_size)

        # 1. Exchange the shape of every rank's tensor.
        local_size = torch.tensor(result.shape, device=result.device)
        local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
        dist.all_gather(local_sizes, local_size, group=group)
        max_size = torch.stack(local_sizes).max(dim=0).values
        all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

        # 2. Identical shapes everywhere: a plain gather is enough.
        if all_sizes_equal:
            return _simple_gather_all_tensors(result, group, world_size)

        # 3. Pad to the per-dimension maximum, gather, then truncate.
        pad_spec = []
        gap = (max_size - local_size).detach().cpu()
        for amount in reversed(gap):
            pad_spec.extend((0, amount.item()))
        padded = torch.nn.functional.pad(result, pad_spec)
        gathered_result = [torch.zeros_like(padded) for _ in range(world_size)]
        dist.all_gather(gathered_result, padded, group)
        for rank, true_size in enumerate(local_sizes):
            trim = [slice(dim_size) for dim_size in true_size]
            gathered_result[rank] = gathered_result[rank][trim]
        return gathered_result

    def tensor2numpy(self, tensor) -> np.ndarray:
        """
        Convert ``tensor`` to a numpy value; plain Python numbers pass through.
        """
        if isinstance(tensor, torch.Tensor):
            return tensor.cpu().detach().numpy()
        elif isinstance(tensor, np.ndarray):
            return tensor
        elif isinstance(tensor, (float, int)):
            return tensor
        else:
            raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

    @staticmethod
    def is_distributed() -> bool:
        """
        Whether torch.distributed is both available and initialised.

        :return:
        """
        return dist.is_available() and dist.is_initialized()

    def move_tensor_to_device(self, tensor, device):
        return tensor.to(device)

    def all_gather_object(self, obj, group=None) -> List:
        if self.is_distributed():
            return fastnlp_torch_all_gather(obj, group=group)
        return [obj]
142
fastNLP/core/metrics/classify_f1_pre_rec_metric.py
Normal file
142
fastNLP/core/metrics/classify_f1_pre_rec_metric.py
Normal file
@ -0,0 +1,142 @@
|
||||
__all__ = [
|
||||
'ClassifyFPreRecMetric'
|
||||
]
|
||||
|
||||
from typing import Union, List
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
import warnings
|
||||
|
||||
from .metric import Metric
|
||||
from .backend import Backend
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.core.utils.utils import seq_len_to_mask
|
||||
|
||||
|
||||
def _compute_f_pre_rec(beta_square, tp, fn, fp):
|
||||
r"""
|
||||
|
||||
:param tp: int, true positive
|
||||
:param fn: int, false negative
|
||||
:param fp: int, false positive
|
||||
:return: (f, pre, rec)
|
||||
"""
|
||||
pre = tp / (fp + tp + 1e-13)
|
||||
rec = tp / (fn + tp + 1e-13)
|
||||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
|
||||
|
||||
return f, pre, rec
|
||||
|
||||
|
||||
class ClassifyFPreRecMetric(Metric):
    """
    F/precision/recall metric for classification tasks.

    :param backend: backend used to hold/aggregate the statistics.
    :param aggregate_when_get_metric: whether to aggregate across distributed
        workers in :meth:`get_metric`.
    :param tag_vocab: optional vocabulary used to map label indices to names
        in the per-class result keys.
    :param encoding_type: unused here; kept for interface compatibility.
    :param ignore_labels: labels to ignore (stored; NOTE(review): not applied
        in update() as visible here — confirm intended behaviour).
    :param only_gross: if True only report the overall f/pre/rec, not the
        per-class values.
    :param f_type: 'micro' (global counts) or 'macro' (average of per-class).
    :param beta: weight of recall relative to precision in the F score.
    """

    def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False,
                 tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None,
                 only_gross: bool = True, f_type='micro', beta=1) -> None:
        super(ClassifyFPreRecMetric, self).__init__(backend=backend,
                                                    aggregate_when_get_metric=aggregate_when_get_metric)
        if f_type not in ('micro', 'macro'):
            raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))

        self.ignore_labels = ignore_labels
        self.f_type = f_type
        self.beta = beta
        self.beta_square = self.beta ** 2
        self.only_gross = only_gross

        self.tag_vocab = tag_vocab

        # Per-class counters; each new key lazily registers a fresh Element
        # that is summed across workers on aggregation.
        self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\
            defaultdict(partial(self.register_element, aggregate_method='sum')),\
            defaultdict(partial(self.register_element, aggregate_method='sum'))

    def get_metric(self) -> dict:
        r"""
        Compute the final F/precision/recall from the statistics accumulated
        by :meth:`update`.

        :return dict evaluate_result: e.g. {"f": float, "pre": float, "rec": float}
        """
        evaluate_result = {}
        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._fn.keys())
            tags.update(set(self._fp.keys()))
            tags.update(set(self._tp.keys()))
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                if self.tag_vocab is not None:
                    tag_name = self.tag_vocab.to_word(tag)
                else:
                    tag_name = int(tag)
                tp = self._tp[tag]
                fn = self._fn[tag]
                fp = self._fp[tag]
                f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
                f_sum += f
                pre_sum += pre
                rec_sum += rec
                if not self.only_gross and tag != '':  # tag != '' guards against the empty tag
                    f_key = 'f-{}'.format(tag_name)
                    pre_key = 'pre-{}'.format(tag_name)
                    rec_key = 'rec-{}'.format(tag_name)
                    evaluate_result[f_key] = f
                    evaluate_result[pre_key] = pre
                    evaluate_result[rec_key] = rec

            if self.f_type == 'macro':
                evaluate_result['f'] = f_sum / len(tags)
                evaluate_result['pre'] = pre_sum / len(tags)
                evaluate_result['rec'] = rec_sum / len(tags)

        if self.f_type == 'micro':
            f, pre, rec = _compute_f_pre_rec(self.beta_square,
                                             sum(self._tp.values()),
                                             sum(self._fn.values()),
                                             sum(self._fp.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec

        for key, value in evaluate_result.items():
            evaluate_result[key] = round(value, 6)

        return evaluate_result

    def update(self, pred, target, seq_len=None):
        """
        Accumulate per-class tp/fp/fn statistics for one batch.

        :param pred: prediction tensor; may carry a trailing class dimension.
        :param target: gold tensor.
        :param seq_len: optional sequence lengths; masks padding positions
            when ``target`` has a length dimension.
        """
        pred = self.tensor2numpy(pred)
        target = self.tensor2numpy(target)
        if seq_len is not None:
            seq_len = self.tensor2numpy(seq_len)

        if seq_len is not None and target.ndim > 1:
            # BUGFIX: was `target.ndim[-1]` — ndim is an int and cannot be
            # subscripted; the sequence length comes from the shape.
            max_len = target.shape[-1]
            masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
        else:
            masks = None

        if pred.ndim == target.ndim:
            if len(pred.flatten()) != len(target.flatten()):
                # BUGFIX: the two interpolations were swapped (target's count
                # printed pred's length and vice versa).
                raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers."
                                   f" while target have element numbers:{len(target.flatten())}, "
                                   f"pred have element numbers: {len(pred.flatten())}")

        elif pred.ndim == target.ndim + 1:
            # BUGFIX: was `len(pred.ndim)` / `len(target.ndim)` — ndim is an
            # int, so len() raised TypeError.
            pred = pred.argmax(axis=-1)
            if seq_len is None and target.ndim > 1:
                warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
        else:
            # BUGFIX: report shapes, not ndim (and ndim is not sliceable).
            raise RuntimeError(f"when pred have "
                               f"size:{pred.shape}, target should have size: {pred.shape} or "
                               f"{pred.shape[:-1]}, got {target.shape}.")
        if masks is not None:
            # NOTE(review): masking by multiplication maps padded positions to
            # class 0, conflating padding with a genuine label 0 — confirm.
            target = target * masks
            pred = pred * masks
        target_idxes = set(target.reshape(-1).tolist())
        for target_idx in target_idxes:
            # BUGFIX: the tp/fp/fn conditions were inverted in the original
            # (tp counted pred==idx where target!=idx, fp where target==idx,
            # fn where target!=idx), producing precision/recall of the wrong
            # quantities.
            self._tp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item()
            self._fp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item()
            self._fn[target_idx] += ((pred != target_idx) * (target == target_idx)).sum().item()
281
fastNLP/core/metrics/element.py
Normal file
281
fastNLP/core/metrics/element.py
Normal file
@ -0,0 +1,281 @@
|
||||
__all__ = [
|
||||
'Element'
|
||||
]
|
||||
|
||||
import os
|
||||
|
||||
from .backend import Backend, AutoBackend
|
||||
from fastNLP.core.log import logger
|
||||
from .utils import AggregateMethodError
|
||||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
|
||||
|
||||
|
||||
class Element:
|
||||
def __init__(self, value: float, aggregate_method, backend: Backend, name=None):
|
||||
self.init_value = value
|
||||
self.aggregate_method = aggregate_method
|
||||
self.name = name
|
||||
if backend == 'auto':
|
||||
raise RuntimeError("You have to specify the backend.")
|
||||
elif isinstance(backend, AutoBackend):
|
||||
self.backend = backend
|
||||
else:
|
||||
self.backend = AutoBackend(backend)
|
||||
|
||||
if self.backend.is_specified():
|
||||
value = self.backend.create_tensor(self.init_value)
|
||||
else:
|
||||
value = None
|
||||
self._value = value
|
||||
self.device = None
|
||||
|
||||
def aggregate(self):
|
||||
"""
|
||||
自动aggregate对应的元素
|
||||
|
||||
"""
|
||||
try:
|
||||
self._value = self.backend.aggregate(self._value, self.aggregate_method)
|
||||
except AggregateMethodError as e:
|
||||
msg = 'If you see this message, please report a bug.'
|
||||
if self.name and e.should_have_aggregate_method:
|
||||
msg = f"Element:{self.name} has no specified `aggregate_method`."
|
||||
elif e.should_have_aggregate_method:
|
||||
msg = "Element has no specified `aggregate_method`."
|
||||
elif self.name and not e.should_have_aggregate_method:
|
||||
msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \
|
||||
f'aggregate_method:{self.aggregate_method}.'
|
||||
elif not e.should_have_aggregate_method:
|
||||
msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \
|
||||
f'aggregate_method:{self.aggregate_method}.'
|
||||
if e.only_warn:
|
||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
|
||||
logger.warning(msg)
|
||||
self._value = self.backend.aggregate(self._value, method=None)
|
||||
else:
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def reset(self):
|
||||
if self.backend.is_specified():
|
||||
self._value = self.backend.fill_value(self._value, self.init_value)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._value
|
||||
|
||||
@value.setter
|
||||
def value(self, value):
|
||||
self._check_value_initialized()
|
||||
self._value = value
|
||||
|
||||
@value.getter
|
||||
def value(self):
|
||||
self._check_value_initialized()
|
||||
return self._value
|
||||
|
||||
def get_scalar(self) -> float:
|
||||
return self.backend.get_scalar(self._value)
|
||||
|
||||
def fill_value(self, value):
|
||||
self._value = self.backend.fill_value(self._value, value)
|
||||
|
||||
def to(self, device):
|
||||
# device这里如何处理呢?
|
||||
if self._value is not None:
|
||||
self._value = self.backend.move_tensor_to_device(self._value, device)
|
||||
self.device = device
|
||||
|
||||
def _check_value_initialized(self):
|
||||
if self._value is None:
|
||||
assert self.backend.is_specified(), f"Backend is not specified, please specify backend in the Metric " \
|
||||
f"initialization."
|
||||
self._value = self.backend.create_tensor(self.init_value)
|
||||
if self.device is not None:
|
||||
self.to(device=self.device)
|
||||
|
||||
def _check_value_when_call(self):
|
||||
if self.value is None:
|
||||
prefix = f'Element:`{self.name}`' if self.name else 'Element'
|
||||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
|
||||
"element, or use it after it being used by the `Metric.compute()` method.")
|
||||
|
||||
def __add__(self, other):
|
||||
self._check_value_when_call()
|
||||
if isinstance(other, Element):
|
||||
self.value += other.value
|
||||
else:
|
||||
self.value += other
|
||||
return self
|
||||
|
||||
def __radd__(self, other):
|
||||
self._check_value_when_call()
|
||||
if isinstance(other, Element):
|
||||
self.value += other.value
|
||||
else:
|
||||
self.value += other
|
||||
return self
|
||||
|
||||
def __sub__(self, other):
|
||||
self._check_value_when_call()
|
||||
if isinstance(other, Element):
|
||||
self.value -= other.value
|
||||
else:
|
||||
self.value -= other
|
||||
return self
|
||||
|
||||
def __rsub__(self, other):
|
||||
self._check_value_when_call()
|
||||
if isinstance(other, Element):
|
||||
self.value -= other.value
|
||||
else:
|
||||
self.value -= other
|
||||
return self
|
||||
|
||||
def __mul__(self, other):
|
||||
self._check_value_when_call()
|
||||
if isinstance(other, Element):
|
||||
self.value *= other.value
|
||||
else:
|
||||
self.value *= other
|
||||
return self
|
||||
|
||||
def __imul__(self, other):
    """Augmented multiplication (``*=``); identical to ``__mul__``."""
    self._check_value_when_call()
    self.value *= other.value if isinstance(other, Element) else other
    return self
|
||||
|
||||
def __floordiv__(self, other):
    """In-place floor division by *other* (Element or plain value)."""
    self._check_value_when_call()
    self.value //= other.value if isinstance(other, Element) else other
    return self
|
||||
|
||||
def __rfloordiv__(self, other):
    """Reflected floor division: store ``other // self`` in place.

    Bug fix: the original computed ``self // other`` despite the reflected
    operand order (``__rfloordiv__`` is invoked as ``other // element``).
    """
    self._check_value_when_call()
    operand = other.value if isinstance(other, Element) else other
    self.value = operand // self.value
    return self
|
||||
|
||||
def __truediv__(self, other):
    """In-place true division by *other* (Element or plain value)."""
    self._check_value_when_call()
    self.value /= other.value if isinstance(other, Element) else other
    return self
|
||||
|
||||
def __rtruediv__(self, other):
    """Reflected true division: store ``other / self`` in place.

    Bug fix: the original computed ``self / other`` despite the reflected
    operand order (``__rtruediv__`` is invoked as ``other / element``).
    """
    self._check_value_when_call()
    operand = other.value if isinstance(other, Element) else other
    self.value = operand / self.value
    return self
|
||||
|
||||
def __mod__(self, other):
    """In-place modulo by *other* (Element or plain value)."""
    self._check_value_when_call()
    self.value %= other.value if isinstance(other, Element) else other
    return self
|
||||
|
||||
def __rmod__(self, other):
    """Reflected modulo: store ``other % self`` in place.

    Bug fix: the original used ``/=`` (true division, apparently copy-pasted
    from ``__rtruediv__``) instead of modulo, and also applied the operands
    in the wrong order for a reflected operator.
    """
    self._check_value_when_call()
    operand = other.value if isinstance(other, Element) else other
    self.value = operand % self.value
    return self
|
||||
|
||||
def __pow__(self, other, modulo=None):
    """In-place exponentiation, supporting the three-argument ``pow`` form.

    :param other: exponent (Element or plain value).
    :param modulo: optional modulus for ``pow(base, exp, modulo)``.
    """
    self._check_value_when_call()
    exponent = other.value if isinstance(other, Element) else other
    if modulo is None:
        self.value **= exponent
    else:
        self.value = pow(self.value, exponent, modulo)
    return self
|
||||
|
||||
def __rpow__(self, other):
    """Reflected exponentiation: store ``other ** self`` in place.

    Bug fix: the original computed ``self ** other`` despite the reflected
    operand order (``__rpow__`` is invoked as ``other ** element``).
    """
    self._check_value_when_call()
    operand = other.value if isinstance(other, Element) else other
    self.value = operand ** self.value
    return self
|
||||
|
||||
def __lt__(self, other) -> bool:
    """Value-based less-than; unwraps Element operands."""
    self._check_value_when_call()
    rhs = other.value if isinstance(other, Element) else other
    return self.value < rhs
|
||||
|
||||
def __le__(self, other) -> bool:
    """Value-based less-than-or-equal; unwraps Element operands."""
    self._check_value_when_call()
    rhs = other.value if isinstance(other, Element) else other
    return self.value <= rhs
|
||||
|
||||
def __eq__(self, other):
    """Value-based equality; unwraps Element operands.

    Deliberately unannotated: for tensor-backed values the comparison may
    return a tensor rather than a plain bool.
    """
    self._check_value_when_call()
    rhs = other.value if isinstance(other, Element) else other
    return self.value == rhs
|
||||
|
||||
def __ne__(self, other) -> bool:
    """Value-based inequality; unwraps Element operands."""
    self._check_value_when_call()
    rhs = other.value if isinstance(other, Element) else other
    return self.value != rhs
|
||||
|
||||
def __ge__(self, other) -> bool:
    """Value-based greater-than-or-equal; unwraps Element operands."""
    self._check_value_when_call()
    rhs = other.value if isinstance(other, Element) else other
    return self.value >= rhs
|
||||
|
||||
def __gt__(self, other) -> bool:
    """Value-based greater-than; unwraps Element operands."""
    self._check_value_when_call()
    rhs = other.value if isinstance(other, Element) else other
    return self.value > rhs
|
||||
|
||||
def __str__(self):
    """Render the element as the string form of its wrapped value."""
    return str(self.value)
|
||||
|
||||
def __repr__(self):
    """Same as ``__str__``: show the wrapped value (not a constructor-style repr)."""
    return str(self.value)
|
||||
|
||||
def __getattr__(self, item):
|
||||
"""
|
||||
为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法
|
||||
:param item:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
if self._value is None:
|
||||
prefix = f'Element:`{self.name}`' if self.name else 'Element'
|
||||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
|
||||
"element, or use it after it being used by the `Metric.compute()` method.")
|
||||
return getattr(self._value, item)
|
||||
except AttributeError as e:
|
||||
raise e
|
184
fastNLP/core/metrics/metric.py
Normal file
184
fastNLP/core/metrics/metric.py
Normal file
@ -0,0 +1,184 @@
|
||||
__all__ = [
|
||||
'Metric'
|
||||
]
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from typing import Union
|
||||
import functools
|
||||
from contextlib import contextmanager
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.metrics.backend import Backend, AutoBackend
|
||||
from fastNLP.core.metrics.element import Element
|
||||
|
||||
|
||||
class Metric:
    """Base class of all fastNLP metrics.

    Subclasses implement :meth:`update` (accumulate statistics over one batch)
    and :meth:`get_metric` (turn the accumulated statistics into the final
    result dict). Mutable state lives in :class:`Element` objects created via
    :meth:`register_element`; elements are reset automatically by
    :meth:`reset` and can be aggregated across processes.
    """

    def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
        """
        :param backend: one of ``torch``, ``paddle``, ``jittor`` or ``auto``. With
            ``auto`` the concrete backend is inferred from the arguments passed
            to ``Metric.update()``; ``auto`` is appropriate in most cases.
        :param aggregate_when_get_metric: whether, when computing the metric,
            identical elements on different processes are aggregated first.
            Meaningless when the backend does not support distributed execution.
        """
        self.backend = AutoBackend(backend)
        self._updated = False
        # Wrap the public entry points so that synchronization, backend
        # selection and element resetting happen automatically for subclasses.
        self.get_metric = self._sync_get_metric(self.get_metric)
        self.update = self._wrap_update(self.update)
        self.reset = self._wrap_auto_reset_elements(self.reset)
        self.aggregate_when_get_metric = aggregate_when_get_metric
        self._cannot_change_element = False
        self._elements = {}

    @property
    def elements(self) -> dict:
        """Mapping of element name -> registered :class:`Element`."""
        return self._elements

    def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element:
        """Register an element; afterwards it can be accessed as ``self.{name}``
        and used much like a tensor of the corresponding backend (supporting
        the usual arithmetic operators).

        Note: declare ``aggregate_method`` if the metric should automatically
        extend to the multi-device case.

        :param name: element name; after registration the element is reachable
            as ``self.{name}``.
        :param value: initial value; ``Metric.reset()`` restores the element to it.
        :param aggregate_method: how to aggregate results across devices;
            meaningless for single-device runs.
        :param backend: backend used to materialize the element (``torch`` ->
            ``torch.Tensor``; ``paddle`` -> ``paddle.Tensor``; ``jittor`` ->
            ``jittor.Var``). With the default ``auto``, fastNLP infers the
            backend from the tensor types passed to ``Metric.update()``; if no
            tensor type is detected, a plain float is used.
        :return: the registered :class:`Element`.
        """
        if backend == 'auto':
            backend = self.backend
        else:
            backend = AutoBackend(backend)

        # Generate a name from the registration order when none is given.
        if name is None:
            name = f'ele_var_{len(self._elements)}'

        element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name)
        self.elements[name] = element
        setattr(self, name, element)
        return element

    def reset(self):
        """Hook for resetting non-element state; registered elements are reset
        to their initial value automatically."""
        pass

    def _wrap_auto_reset_elements(self, reset):
        # Decorate `reset` so that all registered elements are reset first.
        @functools.wraps(reset)
        def _wrap_reset(*args, **kwargs):
            self._updated = False
            for ele in self.elements.values():
                ele.reset()
            reset(*args, **kwargs)

        return _wrap_reset

    def _sync_get_metric(self, get_metric):
        # Decorate `get_metric` so elements are (optionally) aggregated across
        # processes before the computation and restored afterwards.
        @functools.wraps(get_metric)
        def _wrap_get_metric(*args, **kwargs):
            assert self._updated, f"You have to call `{self.__class__.__name__}` update() function before calling " \
                                  f"get_metric()."
            with self.sync(recover=True, aggregate=self.aggregate_when_get_metric):
                results = get_metric(*args, **kwargs)
            return results

        return _wrap_get_metric

    def __setattr__(self, key, value):
        # Once update() has run, registered elements may not be rebound to new
        # objects (in-place modification through the element is still allowed).
        if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True:
            if key in self.elements and value is not self.elements[key]:
                raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}")
        object.__setattr__(self, key, value)

    def _wrap_update(self, update):
        # Decorate `update` to pick the real backend on first use and to lock
        # element rebinding afterwards.
        @functools.wraps(update)
        def _wrap_update(*args, **kwargs):
            self.check_backend(*args, **kwargs)
            self._cannot_change_element = True
            self._updated = True
            return update(*args, **kwargs)

        return _wrap_update

    def check_backend(self, *args, **kwargs):
        """Infer the concrete backend from the ``update()`` arguments when it
        has not been specified yet."""
        if not self.backend.is_specified():
            _args = []
            for arg in args:
                _args.append(arg)
            for arg in kwargs.values():
                _args.append(arg)
            self.backend.choose_real_backend(_args)

    @contextmanager
    def sync(self, recover=True, aggregate=False):
        """Context manager inside which elements that require synchronization
        are aggregated first. With ``recover=True`` the pre-aggregation values
        are restored on exit.
        """
        keep_value = {}
        if aggregate:
            for name, element in self.elements.items():
                # Save the local (pre-aggregation) value...
                keep_value[name] = element.get_scalar()
                # ...then aggregate across processes.
                element.aggregate()

        yield

        if recover and aggregate:
            for name, element in self.elements.items():
                # Restore the saved local value.
                if name in keep_value:
                    element.fill_value(value=keep_value.get(name))

    @abstractmethod
    def update(self, *args, **kwargs):
        """Accumulate statistics for one batch; must be overridden."""
        raise NotImplementedError()

    @abstractmethod
    def get_metric(self) -> dict:
        """Compute and return the final metric results; must be overridden."""
        raise NotImplementedError()

    def set_auto_aggregate_when_get_metric(self, flag: bool):
        """Set whether ``get_metric()`` automatically aggregates elements
        across processes."""
        self.aggregate_when_get_metric = flag

    def __getattr__(self, name: str) -> Element:
        # Fallback lookup for registered elements.
        # Bug fix: the instance dict stores the mapping under '_elements'
        # ('elements' is a class-level property and never appears in
        # self.__dict__), so the original check on 'elements' made this
        # branch dead code.
        if '_elements' in self.__dict__:
            elements = self.__dict__['_elements']
            if name in elements:
                return elements[name]
        raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name))

    def tensor2numpy(self, tensor) -> np.ndarray:
        """Convert a backend tensor to a numpy array.

        (Annotation fix: ``np.array`` is a function; the returned object is an
        ``np.ndarray``.)

        :param tensor: tensor of the currently selected backend.
        :return: the converted numpy array.
        """
        return self.backend.tensor2numpy(tensor)

    def to(self, device):
        """Move every registered element to ``device``.

        :param device: target device.
        :return:
        """
        for element in self.elements.values():
            element.to(device)
|
344
fastNLP/core/metrics/span_f1_pre_rec_metric.py
Normal file
344
fastNLP/core/metrics/span_f1_pre_rec_metric.py
Normal file
@ -0,0 +1,344 @@
|
||||
__all__ = [
|
||||
'SpanFPreRecMetric'
|
||||
]
|
||||
|
||||
from typing import Union, List, Optional
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
|
||||
from fastNLP.core.metrics.backend import Backend
|
||||
from fastNLP.core.metrics.metric import Metric
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
|
||||
|
||||
def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str):
    r"""Verify that the tags in *tag_vocab* match *encoding_type*.

    :param tag_vocab: a tag Vocabulary, or a dict like ``{0: "O", 1: "B-tag1"}``
        (index first, tag second).
    :param encoding_type: bio, bmes, bioes, bmeso
    :return:
    """
    unk_token = '<unk>'
    pad_token = '<pad>'
    if isinstance(tag_vocab, Vocabulary):
        unk_token = tag_vocab.unknown
        pad_token = tag_vocab.padding
        tag_vocab = tag_vocab.idx2word
    # Collect the first character (lower-cased) of every real tag.
    tag_set = {tag[:1].lower() for tag in tag_vocab.values() if tag not in (unk_token, pad_token)}

    remaining = encoding_type
    for tag in tag_set:
        assert tag in remaining, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \
                                 f"encoding_type."
        remaining = remaining.replace(tag, '')  # drop the consumed marker
    if remaining:  # non-empty means some markers of the scheme never appear
        warnings.warn(f"Tag:{remaining} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
                      "encoding_type.")
|
||||
|
||||
|
||||
def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str:
    r"""Infer the encoding scheme (bmes, bioes, bmeso or bio) from a tag vocabulary.

    :param tag_vocab: a tag Vocabulary, or a dict like ``{0: "O", 1: "B-tag1"}``
        (index first, tag second).
    :return: the inferred encoding type string.
    """
    unk_token = '<unk>'
    pad_token = '<pad>'
    if isinstance(tag_vocab, Vocabulary):
        unk_token = tag_vocab.unknown
        pad_token = tag_vocab.padding
        tag_vocab = tag_vocab.idx2word
    # First character (lower-cased) of every real tag.
    tag_set = {tag[:1].lower() for tag in tag_vocab.values() if tag not in (unk_token, pad_token)}

    # The four supported schemes have pairwise-distinct marker sets, so an
    # exact-set lookup reproduces the original if/elif chain.
    known_schemes = {
        frozenset('bmes'): 'bmes',
        frozenset('bio'): 'bio',
        frozenset('bmeso'): 'bmeso',
        frozenset('bioes'): 'bioes',
    }
    encoding = known_schemes.get(frozenset(tag_set))
    if encoding is not None:
        return encoding
    raise RuntimeError("encoding_type cannot be inferred automatically. Only support "
                       "'bio', 'bmes', 'bmeso', 'bioes' type.")
|
||||
|
||||
|
||||
def _bmes_tag_to_spans(tags, ignore_labels=None):
    r"""Decode a BMES tag sequence into labeled spans.

    E.g. ``['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']``
    becomes ``[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)),
    ('actor', (5, 6))]`` (half-open intervals). Plain marker sequences such as
    ``['S', 'B', 'M', 'E', ...]`` are also accepted.

    :param tags: List[str]
    :param ignore_labels: List[str]; spans with these labels are dropped.
    :return: List[Tuple[str, Tuple[int, int]]] as (label, (start, end)).
    """
    ignored = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_marker = None
    for idx, raw_tag in enumerate(tags):
        lowered = raw_tag.lower()
        marker, label = lowered[:1], lowered[2:]
        if marker in ('b', 's'):
            spans.append((label, [idx, idx]))
        elif marker in ('m', 'e') and prev_marker in ('b', 'm') and label == spans[-1][0]:
            # Legal continuation of the current span: extend its right edge.
            spans[-1][1][1] = idx
        else:
            # Illegal transition: start a fresh span anyway (lenient decoding).
            spans.append((label, [idx, idx]))
        prev_marker = marker
    return [(label, (start, end + 1)) for label, (start, end) in spans if label not in ignored]
|
||||
|
||||
|
||||
def _bmeso_tag_to_spans(tags, ignore_labels=None):
    r"""Decode a BMESO tag sequence into labeled spans.

    E.g. ``['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']`` becomes
    ``[('singer', (1, 4))]`` (half-open intervals).

    :param tags: List[str]
    :param ignore_labels: List[str]; spans with these labels are dropped.
    :return: List[Tuple[str, Tuple[int, int]]] as (label, (start, end)).
    """
    ignored = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_marker = None
    for idx, raw_tag in enumerate(tags):
        lowered = raw_tag.lower()
        marker, label = lowered[:1], lowered[2:]
        if marker in ('b', 's'):
            spans.append((label, [idx, idx]))
        elif marker in ('m', 'e') and prev_marker in ('b', 'm') and label == spans[-1][0]:
            # Legal continuation: extend the current span's right edge.
            spans[-1][1][1] = idx
        elif marker == 'o':
            pass  # outside any span
        else:
            # Illegal transition: start a fresh span (lenient decoding).
            spans.append((label, [idx, idx]))
        prev_marker = marker
    return [(label, (start, end + 1)) for label, (start, end) in spans if label not in ignored]
|
||||
|
||||
|
||||
def _bioes_tag_to_spans(tags, ignore_labels=None):
    r"""Decode a BIOES tag sequence into labeled spans.

    E.g. ``['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']`` becomes
    ``[('singer', (1, 4))]`` (half-open intervals).

    :param tags: List[str]
    :param ignore_labels: List[str]; spans with these labels are dropped.
    :return: List[Tuple[str, Tuple[int, int]]] as (label, (start, end)).
    """
    ignored = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_marker = None
    for idx, raw_tag in enumerate(tags):
        lowered = raw_tag.lower()
        marker, label = lowered[:1], lowered[2:]
        if marker in ('b', 's'):
            spans.append((label, [idx, idx]))
        elif marker in ('i', 'e') and prev_marker in ('b', 'i') and label == spans[-1][0]:
            # Legal continuation: extend the current span's right edge.
            spans[-1][1][1] = idx
        elif marker == 'o':
            pass  # outside any span
        else:
            # Illegal transition: start a fresh span (lenient decoding).
            spans.append((label, [idx, idx]))
        prev_marker = marker
    return [(label, (start, end + 1)) for label, (start, end) in spans if label not in ignored]
|
||||
|
||||
|
||||
def _bio_tag_to_spans(tags, ignore_labels=None):
    r"""Decode a BIO tag sequence into labeled spans.

    E.g. ``['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']`` becomes
    ``[('singer', (1, 4))]`` (half-open intervals).

    :param tags: List[str]
    :param ignore_labels: List[str]; spans with these labels are dropped.
    :return: List[Tuple[str, Tuple[int, int]]] as (label, (start, end)).
    """
    ignored = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_marker = None
    for idx, raw_tag in enumerate(tags):
        lowered = raw_tag.lower()
        marker, label = lowered[:1], lowered[2:]
        if marker == 'b':
            spans.append((label, [idx, idx]))
        elif marker == 'i' and prev_marker in ('b', 'i') and label == spans[-1][0]:
            # Legal continuation: extend the current span's right edge.
            spans[-1][1][1] = idx
        elif marker == 'o':  # o tag does not count
            pass
        else:
            # Illegal transition (e.g. dangling I-): start a fresh span.
            spans.append((label, [idx, idx]))
        prev_marker = marker
    return [(label, (start, end + 1)) for label, (start, end) in spans if label not in ignored]
|
||||
|
||||
|
||||
def _compute_f_pre_rec(beta_square, tp, fn, fp):
    r"""Compute the F-beta score, precision and recall from raw counts.

    :param beta_square: beta squared (1 for the standard F1).
    :param tp: int, true positive
    :param fn: int, false negative
    :param fp: int, false positive
    :return: (f, pre, rec)
    """
    # A tiny epsilon keeps all three divisions safe when every count is zero.
    pre = tp / (fp + tp + 1e-13)
    rec = tp / (fn + tp + 1e-13)
    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
    return f, pre, rec
|
||||
|
||||
|
||||
class SpanFPreRecMetric(Metric):
    """Span-level F/precision/recall for sequence-labeling (e.g. NER) tasks.

    Predicted and gold tag sequences are decoded into labeled spans according
    to ``encoding_type``; true/false positives and false negatives are then
    accumulated per span label.
    """

    def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None,
                 encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro',
                 beta=1, aggregate_when_get_metric: bool = True,) -> None:
        """
        :param backend: backend used for the accumulated counts; ``auto``
            infers it from the tensors passed to ``update()``.
        :param tag_vocab: Vocabulary mapping tag ids to tag strings; required.
        :param encoding_type: one of ``bio``/``bmes``/``bmeso``/``bioes``;
            inferred from ``tag_vocab`` when not given.
        :param ignore_labels: span labels excluded from the statistics.
        :param only_gross: when True only the overall f/pre/rec are reported;
            otherwise per-label values are included as well.
        :param f_type: ``micro`` (count-weighted) or ``macro`` (label-averaged).
        :param beta: weight of recall in the F score (beta > 1 favors recall).
        :param aggregate_when_get_metric: aggregate the counts across processes
            before computing the metric (distributed backends only).
        """
        super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
        if f_type not in ('micro', 'macro'):
            raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
        if not isinstance(tag_vocab, Vocabulary):
            raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
        if encoding_type:
            encoding_type = encoding_type.lower()
            _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
            self.encoding_type = encoding_type
        else:
            self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab)

        # Dispatch table instead of the original if/elif chain.
        _span_funcs = {
            'bmes': _bmes_tag_to_spans,
            'bio': _bio_tag_to_spans,
            'bmeso': _bmeso_tag_to_spans,
            'bioes': _bioes_tag_to_spans,
        }
        if self.encoding_type not in _span_funcs:
            raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.")
        self.tag_to_span_func = _span_funcs[self.encoding_type]

        self.ignore_labels = ignore_labels
        self.f_type = f_type
        self.beta = beta
        self.beta_square = self.beta ** 2
        self.only_gross = only_gross
        self.tag_vocab = tag_vocab

        # Per-label counters; an Element is registered lazily the first time a
        # label is encountered.
        self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
        self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
        self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))

    def get_metric(self) -> dict:
        """Compute f/pre/rec from the accumulated per-label counts.

        :return: dict with keys ``f``, ``pre``, ``rec`` and, when
            ``only_gross=False``, per-label ``f-{tag}``/``pre-{tag}``/``rec-{tag}``.
        """
        evaluate_result = {}
        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._false_negatives.keys())
            tags.update(set(self._false_positives.keys()))
            tags.update(set(self._true_positives.keys()))
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                tp = self._true_positives[tag].get_scalar()
                fn = self._false_negatives[tag].get_scalar()
                fp = self._false_positives[tag].get_scalar()
                f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
                f_sum += f
                pre_sum += pre
                rec_sum += rec
                if not self.only_gross and tag != '':  # tag != '' guards the untagged case
                    f_key = 'f-{}'.format(tag)
                    pre_key = 'pre-{}'.format(tag)
                    rec_key = 'rec-{}'.format(tag)
                    evaluate_result[f_key] = f
                    evaluate_result[pre_key] = pre
                    evaluate_result[rec_key] = rec

            if self.f_type == 'macro':
                evaluate_result['f'] = f_sum / len(tags)
                evaluate_result['pre'] = pre_sum / len(tags)
                evaluate_result['rec'] = rec_sum / len(tags)

        if self.f_type == 'micro':
            f, pre, rec = _compute_f_pre_rec(self.beta_square,
                                             sum(val.get_scalar() for val in self._true_positives.values()),
                                             sum(val.get_scalar() for val in self._false_negatives.values()),
                                             sum(val.get_scalar() for val in self._false_positives.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec

        for key, value in evaluate_result.items():
            evaluate_result[key] = round(value, 6)

        return evaluate_result

    def update(self, pred, target, seq_len: Optional[List] = None) -> None:
        r"""Accumulate span statistics for one batch of predictions.

        :param pred: [batch, seq_len] or [batch, seq_len, len(tag_vocab)],
            predicted tags (or tag logits, which are argmax-ed).
        :param target: [batch, seq_len], gold tags.
        :param seq_len: [batch] actual sequence lengths; when None, the full
            padded length is used for every sample.
        :return:
        """
        pred = self.tensor2numpy(pred)
        target = self.tensor2numpy(target)

        if pred.ndim == target.ndim and target.ndim == 2:
            pass

        elif pred.ndim == target.ndim + 1 and target.ndim == 2:
            num_classes = pred.shape[-1]
            pred = pred.argmax(axis=-1)
            if (target >= num_classes).any():
                raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
                                 "id >= {}, the number of classes.".format(num_classes))
        else:
            raise RuntimeError(f"when pred have size:{pred.ndim}, target should have size: {pred.ndim} or "
                               f"{pred.shape[:-1]}, got {target.ndim}.")

        # Robustness fix: seq_len is annotated Optional, but the original
        # crashed with a TypeError on None. Fall back to the padded length.
        if seq_len is None:
            seq_len = [pred.shape[1]] * pred.shape[0]

        batch_size = pred.shape[0]
        pred = pred.tolist()
        target = target.tolist()
        for i in range(batch_size):
            pred_tags = pred[i][:int(seq_len[i])]
            gold_tags = target[i][:int(seq_len[i])]

            pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
            gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]

            pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
            gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)

            # Matched spans are true positives (each gold span matched once);
            # unmatched predictions are false positives, leftover golds are
            # false negatives.
            for span in pred_spans:
                if span in gold_spans:
                    self._true_positives[span[0]] += 1
                    gold_spans.remove(span)
                else:
                    self._false_positives[span[0]] += 1
            for span in gold_spans:
                self._false_negatives[span[0]] += 1
|
91
fastNLP/core/metrics/utils.py
Normal file
91
fastNLP/core/metrics/utils.py
Normal file
@ -0,0 +1,91 @@
|
||||
__all__ = [
|
||||
'func_post_proc'
|
||||
]
|
||||
|
||||
from typing import Any
|
||||
from functools import wraps
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
|
||||
from fastNLP.envs.utils import _module_available
|
||||
|
||||
_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics')
|
||||
if _IS_TORCHMETRICS_AVAILABLE:
|
||||
from torchmetrics import Metric as torchmetrics_Metric
|
||||
|
||||
_IS_ALLENNLP_AVAILABLE = _module_available('allennlp')
|
||||
if _IS_ALLENNLP_AVAILABLE:
|
||||
from allennlp.training.metrics import Metric as allennlp_Metric
|
||||
|
||||
if _NEED_IMPORT_PADDLE:
|
||||
from paddle.metric import Metric as paddle_Metric
|
||||
|
||||
|
||||
def _is_torchmetrics_metric(metric: Any) -> bool:
    """Check whether *metric* is a torchmetrics ``Metric`` instance.

    :param metric: object to check.
    :return: False when torchmetrics is not installed.
    """
    if not _IS_TORCHMETRICS_AVAILABLE:
        return False
    return isinstance(metric, torchmetrics_Metric)
|
||||
|
||||
|
||||
def _is_allennlp_metric(metric: Any) -> bool:
    """Check whether *metric* is an allennlp ``Metric`` instance.

    :param metric: object to check.
    :return: False when allennlp is not installed.
    """
    if not _IS_ALLENNLP_AVAILABLE:
        return False
    return isinstance(metric, allennlp_Metric)
|
||||
|
||||
|
||||
def _is_paddle_metric(metric: Any) -> bool:
    """Check whether *metric* is a paddle ``Metric`` instance.

    (Docstring fix: the original said "allennlp" — copy-pasted from the
    function above — while the code checks for paddle.)

    :param metric: object to check.
    :return: False when paddle is not importable.
    """
    if _NEED_IMPORT_PADDLE:
        return isinstance(metric, paddle_Metric)
    else:
        return False
|
||||
|
||||
|
||||
def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric':
    """Wrap ``metric.{method_name}`` so that its return value is passed through
    ``fn`` before being returned. The modification is made in place on *metric*.

    :param metric: the metric object.
    :param fn: post-processing applied to the method's return value.
    :param method_name: name of the method to wrap.
    :return: the same *metric* object.
    """
    assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \
        f"Parameter `metric` must have a {method_name} function."
    assert callable(fn), "Parameter `fn` must be callable."

    original = getattr(metric, method_name)

    @wraps(original)
    def _patched(*args, **kwargs):
        return fn(original(*args, **kwargs))

    # Marker so callers can detect an already-wrapped method.
    _patched.__wrapped_by_func_post_proc__ = True
    setattr(metric, method_name, _patched)
    return metric
|
||||
|
||||
|
||||
class AggregateMethodError(BaseException):
    """Raised when an element that should declare an aggregate method does not.

    # should_have_aggregate_method: whether an aggregate method was expected.
    # only_warn: when True, the catcher should warn instead of failing hard.
    """

    def __init__(self, should_have_aggregate_method, only_warn=False):
        super(AggregateMethodError, self).__init__(self)
        self.should_have_aggregate_method = should_have_aggregate_method
        self.only_warn = only_warn
|
Loading…
Reference in New Issue
Block a user