mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
metric bug fix
This commit is contained in:
parent
a90a62ab9b
commit
50f1c28b74
@ -115,10 +115,10 @@ class MetricBase(object):
|
||||
|
||||
|
||||
class AccuracyMetric(MetricBase):
|
||||
def __init__(self, predictions=None, targets=None, masks=None, seq_lens=None):
|
||||
def __init__(self, input=None, targets=None, masks=None, seq_lens=None):
|
||||
super().__init__()
|
||||
|
||||
self._init_param_map(predictions=predictions, targets=targets,
|
||||
self._init_param_map(input=input, targets=targets,
|
||||
masks=masks, seq_lens=seq_lens)
|
||||
|
||||
self.total = 0
|
||||
@ -138,7 +138,7 @@ class AccuracyMetric(MetricBase):
|
||||
:return: dict({'acc': float})
|
||||
"""
|
||||
if not isinstance(input, torch.Tensor):
|
||||
raise NameError(f"`predictions` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
|
||||
raise NameError(f"`input` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
|
||||
f"got {type(input)}.")
|
||||
if not isinstance(targets, torch.Tensor):
|
||||
raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
|
||||
@ -157,9 +157,9 @@ class AccuracyMetric(MetricBase):
|
||||
if input.size()==targets.size():
|
||||
pass
|
||||
elif len(input.size())==len(targets.size())+1:
|
||||
predictions = input.argmax(dim=-1)
|
||||
input = input.argmax(dim=-1)
|
||||
else:
|
||||
raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when predictions with "
|
||||
raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when input with "
|
||||
f"size:{input.size()}, targets should with size: {input.size()} or "
|
||||
f"{input.size()[:-1]}, got {targets.size()}.")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user