mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
fix bug in SeqLabelTester
This commit is contained in:
parent
d6ef132207
commit
762a559fab
@ -59,6 +59,8 @@ class BaseTester(object):
|
||||
self.batch_output.append(prediction)
|
||||
if self.save_loss:
|
||||
self.eval_history.append(eval_results)
|
||||
if step % n_print == 0:
|
||||
print('[test step: {:>4}]'.format(step))
|
||||
step += 1
|
||||
|
||||
def prepare_input(self, data_path):
|
||||
@ -134,7 +136,7 @@ class SeqLabelTester(BaseTester):
|
||||
results = torch.Tensor(prediction).view(-1,)
|
||||
# make sure "results" is in the same device as "truth"
|
||||
results = results.to(truth)
|
||||
accuracy = torch.sum(results == truth.view((-1,))) / results.shape[0]
|
||||
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
|
||||
return [loss.data, accuracy.data]
|
||||
|
||||
def metrics(self):
|
||||
@ -153,7 +155,6 @@ class SeqLabelTester(BaseTester):
|
||||
def make_batch(self, iterator, data):
|
||||
return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True)
|
||||
|
||||
|
||||
class ClassificationTester(BaseTester):
|
||||
"""Tester for classification."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user