From 668430d33d4739f43b05cd8696b88af04e7b1781 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Wed, 12 Oct 2022 17:19:34 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=20paddle=20=E5=A4=9A?= =?UTF-8?q?=E5=8D=A1=E4=B8=8B=E6=9B=BF=E6=8D=A2=20sampler=20=E7=9A=84?= =?UTF-8?q?=E9=80=BB=E8=BE=91=E5=8F=8A=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/paddle_driver/fleet.py | 3 ++- tests/core/drivers/paddle_driver/test_fleet.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index fc8c0695..6a38af5e 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -491,7 +491,8 @@ class PaddleFleetDriver(PaddleDriver): rank=self.global_rank ) # TODO 这里暂时统一替换为 BatchSampler - batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) + batch_sampler = BatchSampler(dataset=args.dataset, batch_size=args.batch_size, drop_last=False) + batch_sampler.sampler = sampler return replace_batch_sampler(dataloader, batch_sampler) else: raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index 80d494da..08f191e6 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -667,7 +667,9 @@ class TestSetDistReproDataloader: @magic_argv_env_context @recover_logger @pytest.mark.parametrize("inherit", ([True, False])) + @pytest.mark.skip def test_customized_sampler_dataloader(self, inherit): + # TODO 由于 paddle.io.DataLoader 没有 sampler 参数,因此 prepare_paddle_dataloader 没有 sampler,这里暂时跳过 try: logger.set_stdout('raw', level='info') # 需要检验一下 set_dist_repro_dataloader 是否可以在定制 batch_sampler 的情况下正确运行