修正 paddle 多卡下替换 sampler 的逻辑及测试

This commit is contained in:
x54-729 2022-10-12 17:19:34 +08:00
parent 196c864ff2
commit 668430d33d
2 changed files with 4 additions and 1 deletions

View File

@ -491,7 +491,8 @@ class PaddleFleetDriver(PaddleDriver):
rank=self.global_rank
)
# TODO 这里暂时统一替换为 BatchSampler
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
batch_sampler = BatchSampler(dataset=args.dataset, batch_size=args.batch_size, drop_last=False)
batch_sampler.sampler = sampler
return replace_batch_sampler(dataloader, batch_sampler)
else:
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")

View File

@ -667,7 +667,9 @@ class TestSetDistReproDataloader:
@magic_argv_env_context
@recover_logger
@pytest.mark.parametrize("inherit", ([True, False]))
@pytest.mark.skip
def test_customized_sampler_dataloader(self, inherit):
# TODO 由于 paddle.io.DataLoader 没有 sampler 参数,因此 prepare_paddle_dataloader 没有 sampler这里暂时跳过
try:
logger.set_stdout('raw', level='info')
# 需要检验一下 set_dist_repro_dataloader 是否可以在定制 batch_sampler 的情况下正确运行