mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 12:17:35 +08:00
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0
This commit is contained in:
commit
36149a57b0
@ -19,7 +19,7 @@ from fastNLP.core.utils import (
|
||||
paddle_move_data_to_device,
|
||||
is_in_paddle_dist,
|
||||
)
|
||||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
|
||||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler
|
||||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES
|
||||
from fastNLP.core.log import logger
|
||||
|
||||
@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
return dataloader
|
||||
# evaluator
|
||||
elif dist == "unrepeatdist":
|
||||
sampler = UnrepeatedDistributedSampler(
|
||||
sampler = UnrepeatedSampler(
|
||||
dataset=dataloader.dataset,
|
||||
shuffle=shuffle,
|
||||
seed=int(os.environ.get("FASTNLP_SEED", 0))
|
||||
|
@ -23,11 +23,12 @@ from fastNLP.core.drivers.torch_driver.utils import (
|
||||
ForwardState,
|
||||
_MODE_PARAMETER,
|
||||
reset_seed,
|
||||
replace_sampler
|
||||
replace_sampler,
|
||||
replace_batch_sampler
|
||||
)
|
||||
from fastNLP.core.drivers.utils import distributed_open_proc
|
||||
from fastNLP.core.utils import auto_param_call, check_user_specific_params
|
||||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
|
||||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler
|
||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
|
||||
@ -445,11 +446,25 @@ class TorchDDPDriver(TorchDriver):
|
||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
|
||||
return self._test_step(batch)
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
|
||||
reproducible: bool = False, sampler_or_batch_sampler=None):
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None,
|
||||
reproducible: bool = False):
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
dist = re_instantiate_sampler(dist)
|
||||
dist.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
rank=self.global_rank,
|
||||
pad=True
|
||||
)
|
||||
return replace_batch_sampler(dataloader, dist)
|
||||
|
||||
if isinstance(dist, ReproducibleIterator):
|
||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数;
|
||||
dist = re_instantiate_sampler(dist)
|
||||
dist.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
rank=self.global_rank,
|
||||
pad=True
|
||||
)
|
||||
return replace_sampler(dataloader, dist)
|
||||
|
||||
# trainer, evaluator
|
||||
@ -463,7 +478,15 @@ class TorchDDPDriver(TorchDriver):
|
||||
elif dist == "dist":
|
||||
args = self.get_dataloader_args(dataloader)
|
||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为;
|
||||
if isinstance(args.sampler, ReproducibleIterator):
|
||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
|
||||
batch_sampler = re_instantiate_sampler(args.batch_sampler)
|
||||
batch_sampler.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
rank=self.global_rank,
|
||||
pad=True
|
||||
)
|
||||
return replace_batch_sampler(dataloader, batch_sampler)
|
||||
elif isinstance(args.sampler, ReproducibleIterator):
|
||||
sampler = re_instantiate_sampler(args.sampler)
|
||||
sampler.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
@ -477,7 +500,6 @@ class TorchDDPDriver(TorchDriver):
|
||||
shuffle=args.shuffle,
|
||||
seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
|
||||
)
|
||||
# todo 这个你写个todo吧,有两个角度;第一个是dataloader即使检测到sampler是我们reproducible,也不能直接set_distributeds; 第二个如果是单卡的,也需要替换sampler乃至切换sampler的状态,方式之前多卡,现在切换成单卡运行
|
||||
sampler.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
rank=self.global_rank,
|
||||
@ -487,8 +509,11 @@ class TorchDDPDriver(TorchDriver):
|
||||
|
||||
# evaluator
|
||||
elif dist == "unrepeatdist":
|
||||
# todo @yh,补充 unrepeatdist 相关内容;
|
||||
args = self.get_dataloader_args(dataloader)
|
||||
sampler = UnrepeatedDistributedSampler(
|
||||
|
||||
# todo 判断 batch_sampler;
|
||||
sampler = UnrepeatedSampler(
|
||||
dataset=args.dataset,
|
||||
shuffle=args.shuffle,
|
||||
)
|
||||
|
@ -133,8 +133,10 @@ class TorchSingleDriver(TorchDriver):
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None,
|
||||
reproducible: bool = False):
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
dist = re_instantiate_sampler(dist)
|
||||
return replace_batch_sampler(dataloader, dist)
|
||||
elif isinstance(dist, ReproducibleIterator):
|
||||
dist = re_instantiate_sampler(dist)
|
||||
return replace_sampler(dataloader, dist)
|
||||
|
||||
if reproducible:
|
||||
|
@ -244,8 +244,34 @@ class TorchDriver(Driver):
|
||||
logger.debug("Load model.")
|
||||
|
||||
# 3. 恢复 sampler 的状态;
|
||||
"""
|
||||
使用场景:
|
||||
|
||||
现在sampler/batch_sampler的替换情况:
|
||||
1. 单卡多卡;
|
||||
2. 是否断点重训;
|
||||
|
||||
3. 用户通过 dist 传入;
|
||||
4. 用户自己直接在外面替换dataloader的sampler或者 batchsampler;
|
||||
|
||||
应当确定的规则:
|
||||
batchsampler 优先级高于 sampler;
|
||||
|
||||
单卡:
|
||||
不是断点重训:
|
||||
用户自己
|
||||
|
||||
|
||||
用户不自己在外面直接替换 sampler 或者 batchsampler
|
||||
1. 单卡:
|
||||
|
||||
"""
|
||||
dataloader_args = self.get_dataloader_args(dataloader)
|
||||
|
||||
# todo 先捋一下;
|
||||
# batch_sampler = dataloader_args.batch_sampler
|
||||
# if not (hasattr(batch_sampler, 'load_state_dict') and callable(batch_sampler.load_state_dict)):
|
||||
|
||||
sampler = dataloader_args.sampler
|
||||
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
|
||||
# 说明这里需要使用 ReproduceSampler 来弄一下了
|
||||
|
@ -3,19 +3,24 @@ __all__ = [
|
||||
'SortedSampler',
|
||||
'ConstTokenNumSampler',
|
||||
'ConstantTokenNumSampler',
|
||||
'UnrepeatedDistributedSampler',
|
||||
|
||||
'MixSampler',
|
||||
'InnerSampler',
|
||||
'DopedSampler',
|
||||
'MixSequentialSampler',
|
||||
'PollingSampler',
|
||||
|
||||
'ReproducibleIterator',
|
||||
'RandomSampler',
|
||||
're_instantiate_sampler'
|
||||
|
||||
're_instantiate_sampler',
|
||||
|
||||
'UnrepeatedSampler',
|
||||
"UnrepeatedSortedSampler"
|
||||
]
|
||||
|
||||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler
|
||||
from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler
|
||||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
|
||||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler
|
||||
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
|
||||
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler
|
||||
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler
|
||||
|
||||
|
@ -4,7 +4,6 @@ from typing import Union, List, Iterable, Dict
|
||||
|
||||
__all__ = [
|
||||
'MixSampler',
|
||||
'InnerSampler',
|
||||
'DopedSampler',
|
||||
'MixSequentialSampler',
|
||||
'PollingSampler'
|
||||
|
@ -16,7 +16,6 @@ def re_instantiate_sampler(sampler):
|
||||
return type(sampler)(**all_attributes)
|
||||
|
||||
|
||||
|
||||
class ReproducibleIterator:
|
||||
"""
|
||||
注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler
|
||||
|
@ -7,7 +7,6 @@ __all__ = [
|
||||
"SortedSampler",
|
||||
'ConstTokenNumSampler',
|
||||
"ConstantTokenNumSampler",
|
||||
"UnrepeatedDistributedSampler",
|
||||
]
|
||||
|
||||
from itertools import chain
|
||||
@ -18,7 +17,7 @@ import numpy as np
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
from torch.utils.data import SequentialSampler, Sampler, RandomSampler
|
||||
from torch.utils.data import Sampler
|
||||
else:
|
||||
from fastNLP.core.utils.dummy_class import DummyClass as Sampler
|
||||
|
||||
@ -727,87 +726,3 @@ def k_means_bucketing(lengths, buckets):
|
||||
if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
|
||||
bucket_data[bucket_id].append(idx)
|
||||
return bucket_data
|
||||
|
||||
|
||||
class UnrepeatedDistributedSampler:
|
||||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0):
|
||||
"""
|
||||
考虑在多卡evaluate的场景下,不能重复sample。
|
||||
|
||||
:param dataset:
|
||||
:param shuffle:
|
||||
:param seed:
|
||||
"""
|
||||
self.dataset = dataset
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
|
||||
# 多卡的相关的参数
|
||||
self.num_replicas = 1
|
||||
self.rank = 0
|
||||
self.epoch = -1
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank;
|
||||
:return:
|
||||
"""
|
||||
num_common = len(self.dataset)//self.num_replicas
|
||||
self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
|
||||
return self.num_samples
|
||||
|
||||
def __iter__(self):
|
||||
r"""
|
||||
当前使用num_consumed_samples做法会在交替使用的时候遇到问题;
|
||||
Example:
|
||||
>>> sampler = RandomSampler()
|
||||
>>> iter1 = iter(sampler)
|
||||
>>> iter2 = iter(sampler)
|
||||
>>> next(iter1)
|
||||
>>> next(iter2) # 当前num_consumed_samples的数量会发生变化
|
||||
"""
|
||||
|
||||
indices = self.generate_indices()
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:len(indices):self.num_replicas]
|
||||
assert len(indices) == len(self)
|
||||
|
||||
for index in indices:
|
||||
yield index
|
||||
|
||||
def generate_indices(self) -> List[int]:
|
||||
"""
|
||||
生成随机序列
|
||||
|
||||
:return:
|
||||
"""
|
||||
if self.shuffle:
|
||||
indices = list(range(len(self.dataset)))
|
||||
seed = self.seed + self.epoch
|
||||
rng = np.random.default_rng(abs(seed))
|
||||
rng.shuffle(indices)
|
||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。
|
||||
self.epoch -= 1
|
||||
else:
|
||||
indices = list(range(len(self.dataset)))
|
||||
return indices
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
self.epoch = epoch
|
||||
|
||||
def set_distributed(self, num_replicas, rank):
|
||||
"""
|
||||
该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 sampler 本身后立即被调用;
|
||||
|
||||
:param num_replicas:
|
||||
:param rank:
|
||||
:return:
|
||||
"""
|
||||
assert num_replicas>0 and isinstance(num_replicas, int)
|
||||
assert isinstance(rank, int) and 0<=rank<num_replicas
|
||||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态;
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
|
||||
return self
|
114
fastNLP/core/samplers/unrepeated_sampler.py
Normal file
114
fastNLP/core/samplers/unrepeated_sampler.py
Normal file
@ -0,0 +1,114 @@
|
||||
__all__ = [
|
||||
'UnrepeatedSortedSampler',
|
||||
'UnrepeatedSampler'
|
||||
]
|
||||
|
||||
from typing import List, Union
|
||||
from fastNLP.core.dataset import DataSet
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class UnrepeatedSampler:
|
||||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs):
|
||||
"""
|
||||
考虑在多卡evaluate的场景下,不能重复sample。
|
||||
|
||||
:param dataset:
|
||||
:param shuffle:
|
||||
:param seed:
|
||||
"""
|
||||
self.dataset = dataset
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
|
||||
# 多卡的相关的参数
|
||||
self.num_replicas = kwargs.get('num_replicas', 1)
|
||||
self.rank = kwargs.get('rank', 0)
|
||||
self.epoch = kwargs.get('epoch', -1)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank;
|
||||
:return:
|
||||
"""
|
||||
num_common = len(self.dataset)//self.num_replicas
|
||||
self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
|
||||
return self.num_samples
|
||||
|
||||
def __iter__(self):
|
||||
indices = self.generate_indices()
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:len(indices):self.num_replicas]
|
||||
assert len(indices) == len(self)
|
||||
|
||||
for index in indices:
|
||||
yield index
|
||||
|
||||
def generate_indices(self) -> List[int]:
|
||||
"""
|
||||
生成随机序列
|
||||
|
||||
:return:
|
||||
"""
|
||||
if self.shuffle:
|
||||
indices = list(range(len(self.dataset)))
|
||||
seed = self.seed + self.epoch
|
||||
rng = np.random.default_rng(abs(seed))
|
||||
rng.shuffle(indices)
|
||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。
|
||||
self.epoch -= 1
|
||||
else:
|
||||
indices = list(range(len(self.dataset)))
|
||||
return indices
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
self.epoch = epoch
|
||||
|
||||
def set_distributed(self, num_replicas, rank):
|
||||
"""
|
||||
该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 sampler 本身后立即被调用;
|
||||
|
||||
:param num_replicas:
|
||||
:param rank:
|
||||
:return:
|
||||
"""
|
||||
assert num_replicas>0 and isinstance(num_replicas, int)
|
||||
assert isinstance(rank, int) and 0<=rank<num_replicas
|
||||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态;
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class UnrepeatedSortedSampler(UnrepeatedSampler):
|
||||
def __init__(self, dataset, length:Union[str, List], seed: int = 0):
|
||||
"""
|
||||
将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的
|
||||
batch 数量不完全一致。
|
||||
|
||||
:param dataset: 实现了 __len__ 方法的数据容器。
|
||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的
|
||||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。
|
||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。
|
||||
:param seed: 设置的随机数种子
|
||||
:param kwargs: fastNLP 保留使用
|
||||
"""
|
||||
super().__init__(dataset=dataset, shuffle=False, seed=seed)
|
||||
if isinstance(dataset, DataSet):
|
||||
length = dataset.get_field(length)
|
||||
if not isinstance(length[0], int):
|
||||
length = list(map(len, length))
|
||||
else:
|
||||
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
|
||||
"the length parameter can only be List[int]"
|
||||
|
||||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal."
|
||||
|
||||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。
|
||||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的
|
||||
|
||||
def generate_indices(self) -> List[int]:
|
||||
return self.sorted_indices
|
64
tests/core/samplers/test_unrepeated_sampler.py
Normal file
64
tests/core/samplers/test_unrepeated_sampler.py
Normal file
@ -0,0 +1,64 @@
|
||||
from itertools import chain
|
||||
|
||||
import pytest
|
||||
|
||||
from fastNLP.core.samplers import UnrepeatedSampler, UnrepeatedSortedSampler
|
||||
|
||||
|
||||
class DatasetWithVaryLength:
|
||||
def __init__(self, num_of_data=100):
|
||||
self.data = list(range(num_of_data))
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.data[item]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class TestUnrepeatedSampler:
|
||||
@pytest.mark.parametrize('shuffle', [True, False])
|
||||
def test_single(self, shuffle):
|
||||
num_of_data = 100
|
||||
data = DatasetWithVaryLength(num_of_data)
|
||||
sampler = UnrepeatedSampler(data, shuffle)
|
||||
indexes = set(sampler)
|
||||
assert indexes==set(range(num_of_data))
|
||||
|
||||
@pytest.mark.parametrize('num_replica', [2, 3])
|
||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
|
||||
@pytest.mark.parametrize('shuffle', [False, True])
|
||||
def test_multi(self, num_replica, num_of_data, shuffle):
|
||||
data = DatasetWithVaryLength(num_of_data=num_of_data)
|
||||
samplers = []
|
||||
for i in range(num_replica):
|
||||
sampler = UnrepeatedSampler(dataset=data, shuffle=shuffle)
|
||||
sampler.set_distributed(num_replica, rank=i)
|
||||
samplers.append(sampler)
|
||||
|
||||
indexes = set(chain(*samplers))
|
||||
assert indexes==set(range(num_of_data))
|
||||
|
||||
|
||||
class TestUnrepeatedSortedSampler:
|
||||
@pytest.mark.parametrize('shuffle', [True, False])
|
||||
def test_single(self, shuffle):
|
||||
num_of_data = 100
|
||||
data = DatasetWithVaryLength(num_of_data)
|
||||
sampler = UnrepeatedSortedSampler(data, length=data.data)
|
||||
indexes = list(sampler)
|
||||
assert indexes==list(range(num_of_data-1, -1, -1))
|
||||
|
||||
@pytest.mark.parametrize('num_replica', [2, 3])
|
||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
|
||||
@pytest.mark.parametrize('shuffle', [False, True])
|
||||
def test_multi(self, num_replica, num_of_data, shuffle):
|
||||
data = DatasetWithVaryLength(num_of_data=num_of_data)
|
||||
samplers = []
|
||||
for i in range(num_replica):
|
||||
sampler = UnrepeatedSortedSampler(dataset=data, length=data.data)
|
||||
sampler.set_distributed(num_replica, rank=i)
|
||||
samplers.append(sampler)
|
||||
|
||||
indexes = set(chain(*samplers))
|
||||
assert indexes==set(range(num_of_data))
|
Loading…
Reference in New Issue
Block a user