metric bug fix

This commit is contained in:
yh 2018-12-02 14:29:11 +08:00
parent a90a62ab9b
commit 50f1c28b74

View File

@ -115,10 +115,10 @@ class MetricBase(object):
class AccuracyMetric(MetricBase):
def __init__(self, predictions=None, targets=None, masks=None, seq_lens=None):
def __init__(self, input=None, targets=None, masks=None, seq_lens=None):
super().__init__()
self._init_param_map(predictions=predictions, targets=targets,
self._init_param_map(input=input, targets=targets,
masks=masks, seq_lens=seq_lens)
self.total = 0
@ -138,7 +138,7 @@ class AccuracyMetric(MetricBase):
:return: dict({'acc': float})
"""
if not isinstance(input, torch.Tensor):
raise NameError(f"`predictions` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
raise NameError(f"`input` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
f"got {type(input)}.")
if not isinstance(targets, torch.Tensor):
raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
@ -157,9 +157,9 @@ class AccuracyMetric(MetricBase):
if input.size()==targets.size():
pass
elif len(input.size())==len(targets.size())+1:
predictions = input.argmax(dim=-1)
input = input.argmax(dim=-1)
else:
raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when predictions with "
raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when input with "
f"size:{input.size()}, targets should with size: {input.size()} or "
f"{input.size()[:-1]}, got {targets.size()}.")