修复了 dist 为 None 时的 set_dist_repro_dataloader 的逻辑 (Fixed the `set_dist_repro_dataloader` logic for the case where `dist` is None)

This commit is contained in:
YWMditto 2022-04-14 13:50:53 +08:00
parent 2f23d80ccc
commit 1452aa8f6c

View File

@ -471,12 +471,11 @@ class TorchDDPDriver(TorchDriver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
"control.")
else:
if isinstance(dist, ReproducibleBatchSampler):
dist = re_instantiate_sampler(dist)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleSampler):
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler))
if isinstance(args.sampler, ReproducibleSampler):
return replace_sampler(dataloader, re_instantiate_sampler(args.sampler))
return dataloader
# trainer
elif dist == "dist":