mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
move model to device in DistTrainer
This commit is contained in:
parent
972185dc6c
commit
e1ed6f16e4
@ -165,11 +165,11 @@ class DistTrainer:
|
|||||||
self.grad_scaler = grad_scaler
|
self.grad_scaler = grad_scaler
|
||||||
|
|
||||||
self.set_grad_to_none = kwargs.get('set_grad_to_none', False)
|
self.set_grad_to_none = kwargs.get('set_grad_to_none', False)
|
||||||
|
|
||||||
# init DataParallel
|
# init DataParallel
|
||||||
if isinstance(model, DDP):
|
if isinstance(model, DDP):
|
||||||
self.ddp_model = model
|
self.ddp_model = model
|
||||||
else:
|
else:
|
||||||
|
model.to(self.device)
|
||||||
if parse_version(torch.__version__)>=parse_version('1.1'):
|
if parse_version(torch.__version__)>=parse_version('1.1'):
|
||||||
self.ddp_model = DDP(model, device_ids=[self.local_rank],
|
self.ddp_model = DDP(model, device_ids=[self.local_rank],
|
||||||
output_device=self.local_rank,
|
output_device=self.local_rank,
|
||||||
@ -182,7 +182,6 @@ class DistTrainer:
|
|||||||
self._forward_func = self.model.forward
|
self._forward_func = self.model.forward
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
|
||||||
optimizer = self._get_optimizer(optimizer)
|
optimizer = self._get_optimizer(optimizer)
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
if isinstance(self.train_data, DataSet):
|
if isinstance(self.train_data, DataSet):
|
||||||
|
Loading…
Reference in New Issue
Block a user