mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-29 10:48:40 +08:00
修正 paddle 多卡下替换 sampler 的逻辑及测试
This commit is contained in:
parent
196c864ff2
commit
668430d33d
@ -491,7 +491,8 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
rank=self.global_rank
|
||||
)
|
||||
# TODO 这里暂时统一替换为 BatchSampler
|
||||
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
|
||||
batch_sampler = BatchSampler(dataset=args.dataset, batch_size=args.batch_size, drop_last=False)
|
||||
batch_sampler.sampler = sampler
|
||||
return replace_batch_sampler(dataloader, batch_sampler)
|
||||
else:
|
||||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
|
||||
|
@ -667,7 +667,9 @@ class TestSetDistReproDataloader:
|
||||
@magic_argv_env_context
|
||||
@recover_logger
|
||||
@pytest.mark.parametrize("inherit", ([True, False]))
|
||||
@pytest.mark.skip
|
||||
def test_customized_sampler_dataloader(self, inherit):
|
||||
# TODO 由于 paddle.io.DataLoader 没有 sampler 参数,因此 prepare_paddle_dataloader 没有 sampler,这里暂时跳过
|
||||
try:
|
||||
logger.set_stdout('raw', level='info')
|
||||
# 需要检验一下 set_dist_repro_dataloader 是否可以在定制 batch_sampler 的情况下正确运行
|
||||
|
Loading…
Reference in New Issue
Block a user