mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 11:48:09 +08:00
修复了 dist 为 None 时的 set_dist_repro_dataloader 的逻辑
This commit is contained in:
parent
2f23d80ccc
commit
1452aa8f6c
@ -471,12 +471,11 @@ class TorchDDPDriver(TorchDriver):
|
||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
|
||||
"control.")
|
||||
else:
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
dist = re_instantiate_sampler(dist)
|
||||
return replace_batch_sampler(dataloader, dist)
|
||||
if isinstance(dist, ReproducibleSampler):
|
||||
dist = re_instantiate_sampler(dist)
|
||||
return replace_sampler(dataloader, dist)
|
||||
args = self.get_dataloader_args(dataloader)
|
||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
|
||||
return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler))
|
||||
if isinstance(args.sampler, ReproducibleSampler):
|
||||
return replace_sampler(dataloader, re_instantiate_sampler(args.sampler))
|
||||
return dataloader
|
||||
# trainer
|
||||
elif dist == "dist":
|
||||
|
Loading…
Reference in New Issue
Block a user