fix bug in SeqLabelTester

This commit is contained in:
Yunfan Shao 2018-08-14 00:10:44 +08:00
parent d6ef132207
commit 762a559fab

View File

@ -59,6 +59,8 @@ class BaseTester(object):
self.batch_output.append(prediction)
if self.save_loss:
self.eval_history.append(eval_results)
if step % n_print == 0:
print('[test step: {:>4}]'.format(step))
step += 1
def prepare_input(self, data_path):
@ -134,7 +136,7 @@ class SeqLabelTester(BaseTester):
results = torch.Tensor(prediction).view(-1,)
# make sure "results" is in the same device as "truth"
results = results.to(truth)
accuracy = torch.sum(results == truth.view((-1,))) / results.shape[0]
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
return [loss.data, accuracy.data]
def metrics(self):
@ -153,7 +155,6 @@ class SeqLabelTester(BaseTester):
def make_batch(self, iterator, data):
    """Build evaluation batches by delegating to Action.make_batch.

    Passes the tester's CUDA flag through and requests sequence lengths
    alongside each batch (output_length=True), as sequence labeling needs
    the true lengths for masking/evaluation.
    """
    return Action.make_batch(
        iterator,
        data,
        use_cuda=self.use_cuda,
        output_length=True,
    )
class ClassificationTester(BaseTester):
"""Tester for classification."""