在linux桌面系统上Trainer中使用Tester的tqdm存在bug; 增加一个可选项使得用户可以关闭Tester的tqdm

This commit is contained in:
yh_cc 2019-08-22 19:20:24 +08:00
parent f18ab642d7
commit c38e8986cc
2 changed files with 7 additions and 3 deletions

View File

@ -569,7 +569,7 @@ class FitlogCallback(Callback):
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
metrics=self.trainer.metrics, metrics=self.trainer.metrics,
verbose=0, verbose=0,
use_tqdm=self.trainer.use_tqdm) use_tqdm=self.trainer.test_use_tqdm)
self.testers[key] = tester self.testers[key] = tester
fitlog.add_progress(total_steps=self.n_steps) fitlog.add_progress(total_steps=self.n_steps)
@ -654,7 +654,7 @@ class EvaluateCallback(Callback):
tester = Tester(data=data, model=self.model, tester = Tester(data=data, model=self.model,
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
metrics=self.trainer.metrics, verbose=0, metrics=self.trainer.metrics, verbose=0,
use_tqdm=self.trainer.use_tqdm) use_tqdm=self.trainer.test_use_tqdm)
self.testers[key] = tester self.testers[key] = tester
def on_valid_end(self, eval_result, metric_key, optimizer, better_result): def on_valid_end(self, eval_result, metric_key, optimizer, better_result):

View File

@ -545,6 +545,10 @@ class Trainer(object):
self.logger = logger self.logger = logger
self.use_tqdm = use_tqdm self.use_tqdm = use_tqdm
if 'test_use_tqdm' in kwargs:
self.test_use_tqdm = kwargs.get('test_use_tqdm')
else:
self.test_use_tqdm = self.use_tqdm
self.pbar = None self.pbar = None
self.print_every = abs(self.print_every) self.print_every = abs(self.print_every)
self.kwargs = kwargs self.kwargs = kwargs
@ -555,7 +559,7 @@ class Trainer(object):
batch_size=kwargs.get("dev_batch_size", self.batch_size), batch_size=kwargs.get("dev_batch_size", self.batch_size),
device=None, # 由上面的部分处理device device=None, # 由上面的部分处理device
verbose=0, verbose=0,
use_tqdm=self.use_tqdm) use_tqdm=self.test_use_tqdm)
self.step = 0 self.step = 0
self.start_time = None # start timestamp self.start_time = None # start timestamp