Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-03 04:37:37 +08:00)
修复no_sync的bug
This commit is contained in:
parent bfa9920b16
commit 6b4d4502db
@@ -31,7 +31,7 @@ class Saver:
 folder = Path.cwd()
 folder = Path(folder)
 if not folder.exists():
-raise NotADirectoryError(f"Path '{folder.absolute()}' is not existed!")
+folder.mkdir(parents=True, exist_ok=True)
 elif folder.is_file():
 raise ValueError("Parameter `folder` should be a directory instead of a file.")
 
@@ -36,7 +36,8 @@ class TrainBatchLoop(Loop):
 raise e
 
 trainer.on_train_batch_begin(batch, indices)
-self.batch_step_fn(trainer, batch)
+with trainer.get_no_sync_context(): # 在多卡的时候可能需要关闭 sync
+self.batch_step_fn(trainer, batch)
 trainer.global_forward_batches += 1
 trainer.batch_idx_in_epoch += 1
 
@@ -696,8 +696,9 @@ class Trainer(TrainerEventTrigger):
 self.on_before_backward(outputs)
 loss = self.extract_loss_from_outputs(outputs)
 loss = loss / self.accumulation_steps
-with self.get_no_sync_context():
-self.driver.backward(loss)
+# with self.get_no_sync_context():
+# self.driver.backward(loss)
+self.driver.backward(loss)
 self.on_after_backward()
 
 def zero_grad(self):
Loading…
Reference in New Issue
Block a user