diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index 12356074..d008d4ad 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -3,7 +3,7 @@ __all__ = [
     'prepare_torch_dataloader'
 ]
 
-from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping
+from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List
 
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.collators import Collator
@@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader):
 
         if sampler is None and batch_sampler is None:
             sampler = RandomSampler(dataset, shuffle=shuffle)
+            shuffle = False
 
         super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
                          batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
@@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader):
         else:
             raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
 
+    def get_batch_indices(self) -> List[int]:
+        """
+        Get the indices of the samples that make up the current batch.
+
+        :return: list of dataset indices for the current batch
+        """
+        return self.cur_batch_indices
+
 
 def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
                              batch_size: int = 1,
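
Note on the added `shuffle = False` line: once a `RandomSampler` has been constructed (here fastNLP's own sampler class, which takes a `shuffle` argument), shuffling is the sampler's job, and `torch.utils.data.DataLoader` refuses a `sampler` combined with `shuffle=True`. A minimal sketch of that behavior using plain `torch.utils.data`; the toy dataset below is illustrative, not part of this patch:

```python
# Minimal sketch of why shuffle must be reset once a sampler exists.
# Uses plain torch.utils.data; the toy dataset here is illustrative only.
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
sampler = RandomSampler(dataset)  # the sampler now owns the shuffling

# Passing both a sampler and shuffle=True is rejected by torch:
# "ValueError: sampler option is mutually exclusive with shuffle"
try:
    DataLoader(dataset, sampler=sampler, shuffle=True)
except ValueError as err:
    print(err)

# Correct: hand shuffling to the sampler and pass shuffle=False,
# which is what the added `shuffle = False` line accomplishes.
loader = DataLoader(dataset, sampler=sampler, shuffle=False)
print(next(iter(loader)))
```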
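
A hedged usage sketch for the new `get_batch_indices()` accessor. It assumes `cur_batch_indices` is refreshed by the loader as each batch is yielded; the module path follows the file touched by this patch, but the `DataSet` construction and the `shuffle` keyword of `prepare_torch_dataloader` are assumptions about fastNLP's API, not something this diff shows:

```python
# Hedged usage sketch for get_batch_indices(); assumes cur_batch_indices
# is kept up to date as each batch is produced. The DataSet construction
# below is an assumption about fastNLP's API, not part of this patch.
from fastNLP.core.dataset import DataSet
from fastNLP.core.dataloaders.torch_dataloader.fdl import prepare_torch_dataloader

ds = DataSet({"x": [[1], [2], [3], [4]], "y": [0, 1, 0, 1]})
dl = prepare_torch_dataloader(ds, batch_size=2, shuffle=True)

for batch in dl:
    # e.g. [2, 0]: which dataset rows were collated into this batch,
    # useful for tracing predictions back to their source samples.
    print(dl.get_batch_indices())
```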