move model to device in DistTrainer

This commit is contained in:
yh_cc 2021-10-15 11:01:01 +08:00
parent 972185dc6c
commit e1ed6f16e4

View File

@ -165,11 +165,11 @@ class DistTrainer:
self.grad_scaler = grad_scaler self.grad_scaler = grad_scaler
self.set_grad_to_none = kwargs.get('set_grad_to_none', False) self.set_grad_to_none = kwargs.get('set_grad_to_none', False)
# init DataParallel # init DataParallel
if isinstance(model, DDP): if isinstance(model, DDP):
self.ddp_model = model self.ddp_model = model
else: else:
model.to(self.device)
if parse_version(torch.__version__)>=parse_version('1.1'): if parse_version(torch.__version__)>=parse_version('1.1'):
self.ddp_model = DDP(model, device_ids=[self.local_rank], self.ddp_model = DDP(model, device_ids=[self.local_rank],
output_device=self.local_rank, output_device=self.local_rank,
@ -182,7 +182,6 @@ class DistTrainer:
self._forward_func = self.model.forward self._forward_func = self.model.forward
self.model.to(self.device) self.model.to(self.device)
optimizer = self._get_optimizer(optimizer) optimizer = self._get_optimizer(optimizer)
self.optimizer = optimizer self.optimizer = optimizer
if isinstance(self.train_data, DataSet): if isinstance(self.train_data, DataSet):