1. Add a get_batch_indices function to TorchDataLoader 2. Set shuffle to False after the sampler has been set

This commit is contained in:
x54-729 2022-05-03 08:41:53 +00:00
parent 175ced3905
commit aff84e5955


@@ -3,7 +3,7 @@ __all__ = [
     'prepare_torch_dataloader'
 ]
 
-from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping
+from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List
 
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.collators import Collator
@@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader):
         if sampler is None and batch_sampler is None:
             sampler = RandomSampler(dataset, shuffle=shuffle)
+            shuffle = False
         super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
                          batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
@@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader):
         else:
             raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
+
+    def get_batch_indices(self) -> List[int]:
+        """
+        Get the indices of the samples in the current batch.
+
+        :return:
+        """
+        return self.cur_batch_indices
 
 
 def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
                              batch_size: int = 1,
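
For context, the sketch below is not part of the commit. It uses plain PyTorch to show why shuffle has to be reset to False once a sampler has been built, and a hypothetical loop to show how get_batch_indices would be called; the toy dataset, the sizes, and the names ds/dl are illustrative assumptions, as is the premise that cur_batch_indices is kept up to date while the loader is being iterated.

import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler

dataset = TensorDataset(torch.arange(10).float())
sampler = RandomSampler(dataset)

# Plain PyTorch rejects a custom sampler combined with shuffle=True,
# which is why the constructor above resets shuffle to False after
# building its own RandomSampler:
try:
    DataLoader(dataset, batch_size=4, shuffle=True, sampler=sampler)
except ValueError as err:
    print(err)  # "sampler option is mutually exclusive with shuffle"

# With shuffle=False the sampler alone controls the ordering.
loader = DataLoader(dataset, batch_size=4, shuffle=False, sampler=sampler)

# Hypothetical use of the new accessor on a fastNLP TorchDataLoader,
# assuming cur_batch_indices is refreshed as each batch is produced:
#
#     dl = TorchDataLoader(ds, batch_size=4, shuffle=True)
#     for batch in dl:
#         print(dl.get_batch_indices())  # dataset indices behind `batch`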