From aff84e5955de927ab40ac25cac5eb5d656455468 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 08:41:53 +0000 Subject: [PATCH] =?UTF-8?q?1=E3=80=81=E4=B8=BATorchDataLoader=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0get=5Fbatch=5Findices=E5=87=BD=E6=95=B0=202=E3=80=81?= =?UTF-8?q?=E5=9C=A8=E8=AE=BE=E7=BD=AEsampler=E5=90=8E=E5=B0=86shuffle?= =?UTF-8?q?=E8=AE=BE=E7=BD=AE=E4=B8=BAFalse?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataloaders/torch_dataloader/fdl.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 12356074..d008d4ad 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -3,7 +3,7 @@ __all__ = [ 'prepare_torch_dataloader' ] -from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping +from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List from fastNLP.core.dataset import DataSet from fastNLP.core.collators import Collator @@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader): if sampler is None and batch_sampler is None: sampler = RandomSampler(dataset, shuffle=shuffle) + shuffle=False super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, @@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader): else: raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") + def get_batch_indices(self) -> List[int]: + """ + 获取当前 batch 的 idx + + :return: + """ + return self.cur_batch_indices + def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], batch_size: int = 1,