From c38e8986cc5c692df17ba35c0eeaf59cb36383fc Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 22 Aug 2019 19:20:24 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9C=A8linux=E6=A1=8C=E9=9D=A2=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E4=B8=8ATrainer=E4=B8=AD=E4=BD=BF=E7=94=A8Tester?= =?UTF-8?q?=E7=9A=84tqdm=E5=AD=98=E5=9C=A8bug;=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E4=B8=80=E4=B8=AA=E5=8F=AF=E9=80=89=E9=A1=B9=E4=BD=BF=E5=BE=97?= =?UTF-8?q?=E7=94=A8=E6=88=B7=E5=8F=AF=E4=BB=A5=E5=85=B3=E9=97=ADTester?= =?UTF-8?q?=E7=9A=84tqdm?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callback.py | 4 ++-- fastNLP/core/trainer.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 4ba4b945..24b42b6e 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -569,7 +569,7 @@ class FitlogCallback(Callback): batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), metrics=self.trainer.metrics, verbose=0, - use_tqdm=self.trainer.use_tqdm) + use_tqdm=self.trainer.test_use_tqdm) self.testers[key] = tester fitlog.add_progress(total_steps=self.n_steps) @@ -654,7 +654,7 @@ class EvaluateCallback(Callback): tester = Tester(data=data, model=self.model, batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), metrics=self.trainer.metrics, verbose=0, - use_tqdm=self.trainer.use_tqdm) + use_tqdm=self.trainer.test_use_tqdm) self.testers[key] = tester def on_valid_end(self, eval_result, metric_key, optimizer, better_result): diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 2c52d104..290a89c1 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -545,6 +545,10 @@ class Trainer(object): self.logger = logger self.use_tqdm = use_tqdm + if 'test_use_tqdm' in kwargs: + self.test_use_tqdm = kwargs.get('test_use_tqdm') + else: + self.test_use_tqdm = self.use_tqdm self.pbar = None self.print_every = abs(self.print_every) self.kwargs = kwargs @@ -555,7 +559,7 @@ class Trainer(object): batch_size=kwargs.get("dev_batch_size", self.batch_size), device=None, # 由上面的部分处理device verbose=0, - use_tqdm=self.use_tqdm) + use_tqdm=self.test_use_tqdm) self.step = 0 self.start_time = None # start timestamp