mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
move model to device in DistTrainer
This commit is contained in:
parent
972185dc6c
commit
e1ed6f16e4
@ -165,11 +165,11 @@ class DistTrainer:
|
||||
self.grad_scaler = grad_scaler
|
||||
|
||||
self.set_grad_to_none = kwargs.get('set_grad_to_none', False)
|
||||
|
||||
# init DataParallel
|
||||
if isinstance(model, DDP):
|
||||
self.ddp_model = model
|
||||
else:
|
||||
model.to(self.device)
|
||||
if parse_version(torch.__version__)>=parse_version('1.1'):
|
||||
self.ddp_model = DDP(model, device_ids=[self.local_rank],
|
||||
output_device=self.local_rank,
|
||||
@ -182,7 +182,6 @@ class DistTrainer:
|
||||
self._forward_func = self.model.forward
|
||||
self.model.to(self.device)
|
||||
|
||||
|
||||
optimizer = self._get_optimizer(optimizer)
|
||||
self.optimizer = optimizer
|
||||
if isinstance(self.train_data, DataSet):
|
||||
|
Loading…
Reference in New Issue
Block a user