mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
在linux桌面系统上Trainer中使用Tester的tqdm存在bug; 增加一个可选项使得用户可以关闭Tester的tqdm
This commit is contained in:
parent
f18ab642d7
commit
c38e8986cc
@ -569,7 +569,7 @@ class FitlogCallback(Callback):
|
||||
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
|
||||
metrics=self.trainer.metrics,
|
||||
verbose=0,
|
||||
use_tqdm=self.trainer.use_tqdm)
|
||||
use_tqdm=self.trainer.test_use_tqdm)
|
||||
self.testers[key] = tester
|
||||
fitlog.add_progress(total_steps=self.n_steps)
|
||||
|
||||
@ -654,7 +654,7 @@ class EvaluateCallback(Callback):
|
||||
tester = Tester(data=data, model=self.model,
|
||||
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
|
||||
metrics=self.trainer.metrics, verbose=0,
|
||||
use_tqdm=self.trainer.use_tqdm)
|
||||
use_tqdm=self.trainer.test_use_tqdm)
|
||||
self.testers[key] = tester
|
||||
|
||||
def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
|
||||
|
@ -545,6 +545,10 @@ class Trainer(object):
|
||||
self.logger = logger
|
||||
|
||||
self.use_tqdm = use_tqdm
|
||||
if 'test_use_tqdm' in kwargs:
|
||||
self.test_use_tqdm = kwargs.get('test_use_tqdm')
|
||||
else:
|
||||
self.test_use_tqdm = self.use_tqdm
|
||||
self.pbar = None
|
||||
self.print_every = abs(self.print_every)
|
||||
self.kwargs = kwargs
|
||||
@ -555,7 +559,7 @@ class Trainer(object):
|
||||
batch_size=kwargs.get("dev_batch_size", self.batch_size),
|
||||
device=None, # 由上面的部分处理device
|
||||
verbose=0,
|
||||
use_tqdm=self.use_tqdm)
|
||||
use_tqdm=self.test_use_tqdm)
|
||||
|
||||
self.step = 0
|
||||
self.start_time = None # start timestamp
|
||||
|
Loading…
Reference in New Issue
Block a user