From 762a559fab31f7d245a8163cb20cef3f8f250929 Mon Sep 17 00:00:00 2001 From: Yunfan Shao <15307130288@fudan.edu.cn> Date: Tue, 14 Aug 2018 00:10:44 +0800 Subject: [PATCH] fix bug in SeqLabelTester --- fastNLP/core/tester.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 3799eed1..1e7b654a 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -59,6 +59,8 @@ class BaseTester(object): self.batch_output.append(prediction) if self.save_loss: self.eval_history.append(eval_results) + if step % n_print == 0: + print('[test step: {:>4}]'.format(step)) step += 1 def prepare_input(self, data_path): @@ -134,7 +136,7 @@ class SeqLabelTester(BaseTester): results = torch.Tensor(prediction).view(-1,) # make sure "results" is in the same device as "truth" results = results.to(truth) - accuracy = torch.sum(results == truth.view((-1,))) / results.shape[0] + accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] return [loss.data, accuracy.data] def metrics(self): @@ -153,7 +155,6 @@ class SeqLabelTester(BaseTester): def make_batch(self, iterator, data): return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True) - class ClassificationTester(BaseTester): """Tester for classification."""