mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
在linux桌面系统上Trainer中使用Tester的tqdm存在bug; 增加一个可选项使得用户可以关闭Tester的tqdm
This commit is contained in:
parent
f18ab642d7
commit
c38e8986cc
@ -569,7 +569,7 @@ class FitlogCallback(Callback):
|
|||||||
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
|
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
|
||||||
metrics=self.trainer.metrics,
|
metrics=self.trainer.metrics,
|
||||||
verbose=0,
|
verbose=0,
|
||||||
use_tqdm=self.trainer.use_tqdm)
|
use_tqdm=self.trainer.test_use_tqdm)
|
||||||
self.testers[key] = tester
|
self.testers[key] = tester
|
||||||
fitlog.add_progress(total_steps=self.n_steps)
|
fitlog.add_progress(total_steps=self.n_steps)
|
||||||
|
|
||||||
@ -654,7 +654,7 @@ class EvaluateCallback(Callback):
|
|||||||
tester = Tester(data=data, model=self.model,
|
tester = Tester(data=data, model=self.model,
|
||||||
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
|
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
|
||||||
metrics=self.trainer.metrics, verbose=0,
|
metrics=self.trainer.metrics, verbose=0,
|
||||||
use_tqdm=self.trainer.use_tqdm)
|
use_tqdm=self.trainer.test_use_tqdm)
|
||||||
self.testers[key] = tester
|
self.testers[key] = tester
|
||||||
|
|
||||||
def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
|
def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
|
||||||
|
@ -545,6 +545,10 @@ class Trainer(object):
|
|||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
self.use_tqdm = use_tqdm
|
self.use_tqdm = use_tqdm
|
||||||
|
if 'test_use_tqdm' in kwargs:
|
||||||
|
self.test_use_tqdm = kwargs.get('test_use_tqdm')
|
||||||
|
else:
|
||||||
|
self.test_use_tqdm = self.use_tqdm
|
||||||
self.pbar = None
|
self.pbar = None
|
||||||
self.print_every = abs(self.print_every)
|
self.print_every = abs(self.print_every)
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
@ -555,7 +559,7 @@ class Trainer(object):
|
|||||||
batch_size=kwargs.get("dev_batch_size", self.batch_size),
|
batch_size=kwargs.get("dev_batch_size", self.batch_size),
|
||||||
device=None, # 由上面的部分处理device
|
device=None, # 由上面的部分处理device
|
||||||
verbose=0,
|
verbose=0,
|
||||||
use_tqdm=self.use_tqdm)
|
use_tqdm=self.test_use_tqdm)
|
||||||
|
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.start_time = None # start timestamp
|
self.start_time = None # start timestamp
|
||||||
|
Loading…
Reference in New Issue
Block a user