[bugfix]修复fitlogcallback在disttrainner中无法添加dev_data 的问题 (#348)

fix the distTrainer dev_data
This commit is contained in:
ROGERDJQ 2020-12-16 18:04:22 +08:00 committed by GitHub
parent 84776696cd
commit f17343e19b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -177,8 +177,13 @@ class DistTrainer():
self.batch_size = self.world_size * self.batch_size_per_gpu
self.n_steps = self._get_n_steps()
self.dev_data = dev_data
self.metrics = metrics
self.test_use_tqdm = True
self.kwargs = kwargs
self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu)
# for evaluation, only run eval on master proc
if dev_data and metrics:
cb = _TesterCallback(