修复no_sync的bug

This commit is contained in:
yh_cc 2022-04-20 00:37:19 +08:00
parent bfa9920b16
commit 6b4d4502db
3 changed files with 6 additions and 4 deletions

View File

@@ -31,7 +31,7 @@ class Saver:
folder = Path.cwd()
folder = Path(folder)
if not folder.exists():
raise NotADirectoryError(f"Path '{folder.absolute()}' is not existed!")
folder.mkdir(parents=True, exist_ok=True)
elif folder.is_file():
raise ValueError("Parameter `folder` should be a directory instead of a file.")

View File

@@ -36,7 +36,8 @@ class TrainBatchLoop(Loop):
raise e
trainer.on_train_batch_begin(batch, indices)
self.batch_step_fn(trainer, batch)
with trainer.get_no_sync_context(): # 在多卡的时候可能需要关闭 sync
self.batch_step_fn(trainer, batch)
trainer.global_forward_batches += 1
trainer.batch_idx_in_epoch += 1

View File

@@ -696,8 +696,9 @@ class Trainer(TrainerEventTrigger):
self.on_before_backward(outputs)
loss = self.extract_loss_from_outputs(outputs)
loss = loss / self.accumulation_steps
with self.get_no_sync_context():
self.driver.backward(loss)
# with self.get_no_sync_context():
# self.driver.backward(loss)
self.driver.backward(loss)
self.on_after_backward()
def zero_grad(self):