mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 11:17:50 +08:00
1、为TorchDataLoader添加get_batch_indices函数 2、在设置sampler后将shuffle设置为False
This commit is contained in:
parent
175ced3905
commit
aff84e5955
@ -3,7 +3,7 @@ __all__ = [
|
||||
'prepare_torch_dataloader'
|
||||
]
|
||||
|
||||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping
|
||||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.collators import Collator
|
||||
@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader):
|
||||
|
||||
if sampler is None and batch_sampler is None:
|
||||
sampler = RandomSampler(dataset, shuffle=shuffle)
|
||||
shuffle=False
|
||||
|
||||
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
|
||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
|
||||
@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader):
|
||||
else:
|
||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
||||
|
||||
def get_batch_indices(self) -> List[int]:
|
||||
"""
|
||||
获取当前 batch 的 idx
|
||||
|
||||
:return:
|
||||
"""
|
||||
return self.cur_batch_indices
|
||||
|
||||
|
||||
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
|
||||
batch_size: int = 1,
|
||||
|
Loading…
Reference in New Issue
Block a user