mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 12:17:35 +08:00
fix conflict
This commit is contained in:
commit
9e86908811
@ -23,7 +23,6 @@ from fastNLP.core.drivers import Driver
|
||||
from fastNLP.core.drivers.utils import choose_driver
|
||||
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
|
||||
from fastNLP.envs import rank_zero_call
|
||||
from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME
|
||||
|
||||
|
@ -49,13 +49,13 @@ class Driver(ABC):
|
||||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的
|
||||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist";
|
||||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None;
|
||||
注意当 dist 为 ReproducibleIterator, RandomBatchSampler 时,是断点重训加载时 driver.load 函数在调用;
|
||||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用;
|
||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数;
|
||||
|
||||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得
|
||||
可以可以加载。
|
||||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外,
|
||||
如果传入的 dataloader 中是 ReproducibleSampler 或者 RandomBatchSampler 需要重新初始化一个放入返回的
|
||||
如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的
|
||||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。
|
||||
"""
|
||||
if dist is None and reproducible is False:
|
||||
|
@ -3,7 +3,7 @@ from typing import Dict, Union
|
||||
from .jittor_driver import JittorDriver
|
||||
from fastNLP.core.utils import auto_param_call
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
|
||||
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor
|
||||
@ -99,10 +99,10 @@ class JittorSingleDriver(JittorDriver):
|
||||
def is_distributed(self):
|
||||
return False
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler],
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
|
||||
reproducible: bool = False, sampler_or_batch_sampler=None):
|
||||
# reproducible 的相关功能暂时没有实现
|
||||
if isinstance(dist, RandomBatchSampler):
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
raise NotImplementedError
|
||||
dataloader.batch_sampler = dist_sample
|
||||
if isinstance(dist, ReproducibleSampler):
|
||||
|
@ -11,11 +11,7 @@ from fastNLP.core.utils import (
|
||||
get_paddle_device_id,
|
||||
paddle_move_data_to_device,
|
||||
)
|
||||
from fastNLP.core.samplers import (
|
||||
ReproducibleBatchSampler,
|
||||
ReproducibleIterator,
|
||||
re_instantiate_sampler,
|
||||
)
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
|
||||
from fastNLP.core.log import logger
|
||||
|
||||
if _NEED_IMPORT_PADDLE:
|
||||
@ -141,12 +137,17 @@ class PaddleSingleDriver(PaddleDriver):
|
||||
"""
|
||||
return paddle_move_data_to_device(batch, "gpu:0")
|
||||
|
||||
<<<<<<< HEAD
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
|
||||
=======
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
|
||||
>>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1
|
||||
reproducible: bool = False, sampler_or_batch_sampler=None):
|
||||
# 暂时不支持IteratorDataset
|
||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \
|
||||
"FastNLP does not support `IteratorDataset` now."
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
<<<<<<< HEAD
|
||||
return replace_batch_sampler(dataloader, dist)
|
||||
elif isinstance(dist, ReproducibleIterator):
|
||||
return replace_sampler(dataloader, dist)
|
||||
@ -164,6 +165,25 @@ class PaddleSingleDriver(PaddleDriver):
|
||||
batch_sampler=args.batch_sampler,
|
||||
batch_size=args.batch_size,
|
||||
drop_last=args.drop_last
|
||||
=======
|
||||
dataloader.batch_sampler = dist
|
||||
return dataloader
|
||||
if isinstance(dist, ReproducibleSampler):
|
||||
dataloader.batch_sampler.sampler = dist
|
||||
return dataloader
|
||||
|
||||
if reproducible:
|
||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
|
||||
return dataloader
|
||||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
|
||||
return dataloader
|
||||
else:
|
||||
# TODO
|
||||
batch_sampler = ReproducibleBatchSampler(
|
||||
batch_sampler=dataloader.batch_sampler,
|
||||
batch_size=dataloader.batch_sampler.batch_size,
|
||||
drop_last=dataloader.drop_last
|
||||
>>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1
|
||||
)
|
||||
return replace_batch_sampler(dataloader, batch_sampler)
|
||||
else:
|
||||
|
@ -28,7 +28,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
|
||||
)
|
||||
from fastNLP.core.drivers.utils import distributed_open_proc
|
||||
from fastNLP.core.utils import auto_param_call, check_user_specific_params
|
||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \
|
||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \
|
||||
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
|
||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
|
||||
from fastNLP.core.log import logger
|
||||
@ -446,11 +446,11 @@ class TorchDDPDriver(TorchDriver):
|
||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
|
||||
return self._test_step(batch)
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None,
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None,
|
||||
reproducible: bool = False):
|
||||
# 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用;
|
||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用;
|
||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数;
|
||||
if isinstance(dist, RandomBatchSampler):
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
dist.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
rank=self.global_rank,
|
||||
@ -472,7 +472,7 @@ class TorchDDPDriver(TorchDriver):
|
||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
|
||||
"control.")
|
||||
else:
|
||||
if isinstance(dist, RandomBatchSampler):
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
dist = re_instantiate_sampler(dist)
|
||||
return replace_batch_sampler(dataloader, dist)
|
||||
if isinstance(dist, ReproducibleSampler):
|
||||
@ -483,7 +483,7 @@ class TorchDDPDriver(TorchDriver):
|
||||
elif dist == "dist":
|
||||
args = self.get_dataloader_args(dataloader)
|
||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为;
|
||||
if isinstance(args.batch_sampler, RandomBatchSampler):
|
||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
|
||||
batch_sampler = re_instantiate_sampler(args.batch_sampler)
|
||||
batch_sampler.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
|
@ -13,7 +13,7 @@ __all__ = [
|
||||
from .torch_driver import TorchDriver
|
||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
|
||||
from fastNLP.core.utils import auto_param_call
|
||||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
|
||||
from fastNLP.core.log import logger
|
||||
|
||||
|
||||
@ -129,18 +129,18 @@ class TorchSingleDriver(TorchDriver):
|
||||
else:
|
||||
return self._test_step(batch)
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None,
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
|
||||
reproducible: bool = False):
|
||||
|
||||
# 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用;
|
||||
if isinstance(dist, RandomBatchSampler):
|
||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用;
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
return replace_batch_sampler(dataloader, dist)
|
||||
elif isinstance(dist, ReproducibleSampler):
|
||||
return replace_sampler(dataloader, dist)
|
||||
|
||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用;
|
||||
args = self.get_dataloader_args(dataloader)
|
||||
if isinstance(args.batch_sampler, RandomBatchSampler):
|
||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
|
||||
batch_sampler = re_instantiate_sampler(args.batch_sampler)
|
||||
return replace_batch_sampler(dataloader, batch_sampler)
|
||||
elif isinstance(args.sampler, ReproducibleSampler):
|
||||
@ -148,7 +148,7 @@ class TorchSingleDriver(TorchDriver):
|
||||
return replace_sampler(dataloader, sampler)
|
||||
|
||||
if reproducible:
|
||||
batch_sampler = RandomBatchSampler(
|
||||
batch_sampler = ReproducibleBatchSampler(
|
||||
batch_sampler=args.batch_sampler,
|
||||
batch_size=args.batch_size,
|
||||
drop_last=args.drop_last
|
||||
|
@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
|
||||
from fastNLP.envs import rank_zero_call
|
||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
|
||||
|
||||
|
||||
class TorchDriver(Driver):
|
||||
@ -183,9 +183,9 @@ class TorchDriver(Driver):
|
||||
|
||||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch;
|
||||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的
|
||||
# sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `RandomBatchSampler`;
|
||||
# sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`;
|
||||
dataloader_args = self.get_dataloader_args(dataloader)
|
||||
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
|
||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
|
||||
sampler = dataloader_args.batch_sampler
|
||||
elif dataloader_args.sampler:
|
||||
sampler = dataloader_args.sampler
|
||||
@ -245,15 +245,14 @@ class TorchDriver(Driver):
|
||||
|
||||
# 3. 恢复 sampler 的状态;
|
||||
dataloader_args = self.get_dataloader_args(dataloader)
|
||||
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
|
||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
|
||||
sampler = dataloader_args.batch_sampler
|
||||
elif isinstance(dataloader_args.sampler, ReproducibleIterator):
|
||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
|
||||
sampler = dataloader_args.sampler
|
||||
elif self.is_distributed():
|
||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
|
||||
"`RandomBatchSampler` or `ReproducibleIterator`.")
|
||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
|
||||
else:
|
||||
sampler = RandomBatchSampler(
|
||||
sampler = ReproducibleBatchSampler(
|
||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
|
||||
batch_size=dataloader_args.batch_size,
|
||||
drop_last=dataloader_args.drop_last
|
||||
@ -263,7 +262,7 @@ class TorchDriver(Driver):
|
||||
|
||||
# 4. 修改 trainer_state.batch_idx_in_epoch
|
||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler;
|
||||
if not isinstance(sampler, RandomBatchSampler):
|
||||
if not isinstance(sampler, ReproducibleBatchSampler):
|
||||
if dataloader_args.drop_last:
|
||||
batch_idx_in_epoch = len(
|
||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
|
||||
|
@ -19,6 +19,10 @@ __all__ = [
|
||||
"UnrepeatedSortedSampler",
|
||||
"UnrepeatedSequentialSampler",
|
||||
|
||||
"RandomBatchSampler",
|
||||
"BucketedBatchSampler",
|
||||
"ReproducibleBatchSampler",
|
||||
|
||||
"re_instantiate_sampler",
|
||||
"conversion_between_reproducible_and_unrepeated_sampler"
|
||||
]
|
||||
@ -28,5 +32,5 @@ from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, Unre
|
||||
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
|
||||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
|
||||
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler
|
||||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler
|
||||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
|
||||
|
||||
|
@ -17,6 +17,9 @@ from abc import abstractmethod
|
||||
|
||||
|
||||
class ReproducibleBatchSampler:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_distributed(self, num_replicas, rank, pad=True):
|
||||
raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.")
|
||||
@ -41,6 +44,10 @@ class ReproducibleBatchSampler:
|
||||
def set_epoch(self, epoch):
|
||||
pass
|
||||
|
||||
@property
|
||||
def batch_idx_in_epoch(self):
|
||||
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
|
||||
|
||||
|
||||
class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
|
||||
@ -54,6 +61,8 @@ class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。
|
||||
:param kwargs: fastNLP 内部使用。
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.batch_sampler = batch_sampler
|
||||
self.batch_size = batch_size
|
||||
self.drop_last = drop_last
|
||||
|
Loading…
Reference in New Issue
Block a user