mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 13:17:51 +08:00
[bugfix]修复fitlogcallback在disttrainner中无法添加dev_data 的问题 (#348)
fix the distTrainer dev_data
This commit is contained in:
parent
84776696cd
commit
f17343e19b
@ -177,8 +177,13 @@ class DistTrainer():
|
||||
self.batch_size = self.world_size * self.batch_size_per_gpu
|
||||
self.n_steps = self._get_n_steps()
|
||||
|
||||
self.dev_data = dev_data
|
||||
self.metrics = metrics
|
||||
self.test_use_tqdm = True
|
||||
self.kwargs = kwargs
|
||||
self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
|
||||
dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu)
|
||||
|
||||
# for evaluation, only run eval on master proc
|
||||
if dev_data and metrics:
|
||||
cb = _TesterCallback(
|
||||
|
Loading…
Reference in New Issue
Block a user