fix conflict

x54-729 2022-04-11 14:36:23 +00:00
commit 9e86908811
9 changed files with 64 additions and 33 deletions


@@ -23,7 +23,6 @@ from fastNLP.core.drivers import Driver
 from fastNLP.core.drivers.utils import choose_driver
 from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
 from fastNLP.envs import rank_zero_call
-from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_MODEL_FILENAME


@@ -49,13 +49,13 @@ class Driver(ABC):
             duplicated samples across different gpus; 'unrepeatdist' means the dataloader should guarantee that the data iterated on all gpus, merged together, equals exactly the original
             dataset, and the number of batches is allowed to differ across gpus. When the trainer kwarg `use_dist_sampler` is True, this value is "dist",
             otherwise it is None; when the evaluator kwarg `use_dist_sampler` is True, this value is "unrepeatdist", otherwise it is None;
-            note that when dist is a ReproducibleIterator or RandomBatchSampler, this function is being called by driver.load while resuming from a checkpoint;
+            note that when dist is a ReproducibleSampler or ReproducibleBatchSampler, this function is being called by driver.load while resuming from a checkpoint;
             when dist is a str or None, the trainer calls this function during initialization;
         :param reproducible: if False, nothing special needs to be considered; if True, the returned dataloader must be able to save its current
             iteration state so that it can be restored later;
         :return: should return a new dataloader object with its sampler replaced (note that a new dataloader object must be returned here); moreover,
-            if the incoming dataloader contains a ReproducibleSampler or RandomBatchSampler, a fresh instance must be created and placed into the returned
+            if the incoming dataloader contains a ReproducibleSampler or ReproducibleBatchSampler, a fresh instance must be created and placed into the returned
             dataloader; if dist is empty and reproducible is False, the original object may be returned directly.
         """
         if dist is None and reproducible is False:
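
An aside on the contract above: "the returned dataloader can save its current iteration state" boils down to the sampler exposing a state_dict/load_state_dict round trip. A minimal sketch of that idea, using a hypothetical ToyReproducibleSampler rather than any actual fastNLP class:

class ToyReproducibleSampler:
    def __init__(self, num_samples):
        self.num_samples = num_samples
        self.num_consumed = 0  # how many indices have been yielded so far

    def __iter__(self):
        while self.num_consumed < self.num_samples:
            yield self.num_consumed
            self.num_consumed += 1

    def state_dict(self):
        return {"num_consumed": self.num_consumed}

    def load_state_dict(self, state):
        self.num_consumed = state["num_consumed"]

# consume a few samples, checkpoint, then resume exactly where we left off
sampler = ToyReproducibleSampler(10)
it = iter(sampler)
first = [next(it) for _ in range(4)]   # [0, 1, 2, 3]
ckpt = sampler.state_dict()

restored = ToyReproducibleSampler(10)
restored.load_state_dict(ckpt)
rest = list(restored)                  # [4, 5, 6, 7, 8, 9]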


@@ -3,7 +3,7 @@ from typing import Dict, Union
 from .jittor_driver import JittorDriver
 from fastNLP.core.utils import auto_param_call
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 if _NEED_IMPORT_JITTOR:
     import jittor
@@ -99,10 +99,10 @@ class JittorSingleDriver(JittorDriver):
     def is_distributed(self):
         return False
-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler],
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
         # the reproducible-related functionality is not implemented yet
-        if isinstance(dist, RandomBatchSampler):
+        if isinstance(dist, ReproducibleBatchSampler):
             raise NotImplementedError
             dataloader.batch_sampler = dist_sample
         if isinstance(dist, ReproducibleSampler):


@@ -11,11 +11,7 @@ from fastNLP.core.utils import (
     get_paddle_device_id,
     paddle_move_data_to_device,
 )
-from fastNLP.core.samplers import (
-    ReproducibleBatchSampler,
-    ReproducibleIterator,
-    re_instantiate_sampler,
-)
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 from fastNLP.core.log import logger
 if _NEED_IMPORT_PADDLE:
@@ -141,12 +137,17 @@ class PaddleSingleDriver(PaddleDriver):
         """
         return paddle_move_data_to_device(batch, "gpu:0")
+<<<<<<< HEAD
     def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
+=======
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
+>>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
         # IteratorDataset is not supported yet
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."
         if isinstance(dist, ReproducibleBatchSampler):
+<<<<<<< HEAD
             return replace_batch_sampler(dataloader, dist)
         elif isinstance(dist, ReproducibleIterator):
             return replace_sampler(dataloader, dist)
@@ -164,6 +165,25 @@ class PaddleSingleDriver(PaddleDriver):
                 batch_sampler=args.batch_sampler,
                 batch_size=args.batch_size,
                 drop_last=args.drop_last
+=======
+            dataloader.batch_sampler = dist
+            return dataloader
+        if isinstance(dist, ReproducibleSampler):
+            dataloader.batch_sampler.sampler = dist
+            return dataloader
+        if reproducible:
+            if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
+                return dataloader
+            elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
+                return dataloader
+            else:
+                # TODO
+                batch_sampler = ReproducibleBatchSampler(
+                    batch_sampler=dataloader.batch_sampler,
+                    batch_size=dataloader.batch_sampler.batch_size,
+                    drop_last=dataloader.drop_last
+>>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1
             )
             return replace_batch_sampler(dataloader, batch_sampler)
         else:
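
The incoming side of the conflict above (between the ======= and >>>>>>> markers) wraps the dataloader's existing batch_sampler in a ReproducibleBatchSampler that also records batch_size and drop_last. A rough sketch of that wrapping idea, with a hypothetical BatchSamplerWrapper standing in for the real class:

class BatchSamplerWrapper:
    def __init__(self, batch_sampler, batch_size, drop_last):
        self.batch_sampler = batch_sampler  # the original batch sampler is kept and delegated to
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_consumed_batches = 0

    def __iter__(self):
        for batch in self.batch_sampler:
            self.num_consumed_batches += 1
            yield batch

    def state_dict(self):
        return {"num_consumed_batches": self.num_consumed_batches}

# a plain list of index lists stands in for a real batch sampler here
wrapped = BatchSamplerWrapper([[0, 1], [2, 3], [4, 5]], batch_size=2, drop_last=False)
next(iter(wrapped))                    # consume one batch
assert wrapped.state_dict() == {"num_consumed_batches": 1}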


@@ -28,7 +28,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
 )
 from fastNLP.core.drivers.utils import distributed_open_proc
 from fastNLP.core.utils import auto_param_call, check_user_specific_params
-from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \
+from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \
     re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
 from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
 from fastNLP.core.log import logger
@@ -446,11 +446,11 @@ class TorchDDPDriver(TorchDriver):
         # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
         return self._test_step(batch)
-    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None,
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None,
                                   reproducible: bool = False):
-        # if dist is a RandomBatchSampler or ReproducibleIterator, driver.load is calling this while resuming from a checkpoint;
+        # if dist is a ReproducibleBatchSampler or ReproducibleSampler, driver.load is calling this while resuming from a checkpoint;
         # note that dist_sampler.set_distributed does not need to be called here; if the user uses TorchDDPDriver, it was already called during Trainer initialization;
-        if isinstance(dist, RandomBatchSampler):
+        if isinstance(dist, ReproducibleBatchSampler):
             dist.set_distributed(
                 num_replicas=self.world_size,
                 rank=self.global_rank,
@@ -472,7 +472,7 @@
             raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
                                "control.")
         else:
-            if isinstance(dist, RandomBatchSampler):
+            if isinstance(dist, ReproducibleBatchSampler):
                 dist = re_instantiate_sampler(dist)
             return replace_batch_sampler(dataloader, dist)
         if isinstance(dist, ReproducibleSampler):
@@ -483,7 +483,7 @@
         elif dist == "dist":
             args = self.get_dataloader_args(dataloader)
             # if the user's trainer.use_dist_sampler is True, whether checkpoint retraining is in progress does not affect the behavior here;
-            if isinstance(args.batch_sampler, RandomBatchSampler):
+            if isinstance(args.batch_sampler, ReproducibleBatchSampler):
                 batch_sampler = re_instantiate_sampler(args.batch_sampler)
                 batch_sampler.set_distributed(
                     num_replicas=self.world_size,
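
The set_distributed(num_replicas=..., rank=...) call above asks the sampler to shard its indices across ranks. A minimal sketch of the arithmetic such a call typically implies, with a hypothetical shard_indices helper (pad=True repeats leading samples so every rank gets the same count):

def shard_indices(indices, num_replicas, rank, pad=True):
    if pad and len(indices) % num_replicas != 0:
        # repeat leading samples so the length divides evenly across ranks
        shortfall = num_replicas - len(indices) % num_replicas
        indices = indices + indices[:shortfall]
    return indices[rank::num_replicas]  # rank-strided, disjoint slice

# 10 samples over 4 replicas: padded to 12, so 3 samples per rank
assert shard_indices(list(range(10)), num_replicas=4, rank=0) == [0, 4, 8]
assert shard_indices(list(range(10)), num_replicas=4, rank=2) == [2, 6, 0]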


@@ -13,7 +13,7 @@ __all__ = [
 from .torch_driver import TorchDriver
 from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
 from fastNLP.core.utils import auto_param_call
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
 from fastNLP.core.log import logger
@@ -129,18 +129,18 @@ class TorchSingleDriver(TorchDriver):
         else:
             return self._test_step(batch)
-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None,
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
                                   reproducible: bool = False):
-        # if dist is a RandomBatchSampler or ReproducibleIterator, driver.load is calling this while resuming from a checkpoint;
-        if isinstance(dist, RandomBatchSampler):
+        # if dist is a ReproducibleBatchSampler or ReproducibleIterator, driver.load is calling this while resuming from a checkpoint;
+        if isinstance(dist, ReproducibleBatchSampler):
             return replace_batch_sampler(dataloader, dist)
         elif isinstance(dist, ReproducibleSampler):
             return replace_sampler(dataloader, dist)
         # if dist is a str or None, this is being called during trainer initialization;
         args = self.get_dataloader_args(dataloader)
-        if isinstance(args.batch_sampler, RandomBatchSampler):
+        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
             batch_sampler = re_instantiate_sampler(args.batch_sampler)
             return replace_batch_sampler(dataloader, batch_sampler)
         elif isinstance(args.sampler, ReproducibleSampler):
@@ -148,7 +148,7 @@
             return replace_sampler(dataloader, sampler)
         if reproducible:
-            batch_sampler = RandomBatchSampler(
+            batch_sampler = ReproducibleBatchSampler(
                 batch_sampler=args.batch_sampler,
                 batch_size=args.batch_size,
                 drop_last=args.drop_last
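
re_instantiate_sampler above exists so the returned dataloader gets a fresh sampler object instead of sharing mutable iteration state with the original. A sketch of that pattern under the assumption that samplers record their constructor arguments (RecordingSampler and re_instantiate are illustrative stand-ins, not fastNLP's helpers):

class RecordingSampler:
    def __init__(self, num_samples, shuffle=True):
        # remember the constructor arguments so a clean copy can be rebuilt
        self._init_kwargs = {"num_samples": num_samples, "shuffle": shuffle}
        self.num_samples = num_samples
        self.shuffle = shuffle

def re_instantiate(sampler):
    return type(sampler)(**sampler._init_kwargs)

s1 = RecordingSampler(100, shuffle=False)
s2 = re_instantiate(s1)
assert s2 is not s1 and s2.num_samples == 100 and s2.shuffle is False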


@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 class TorchDriver(Driver):
@@ -183,9 +183,9 @@ class TorchDriver(Driver):
         # 1. the sampler state, since we support resume training, i.e. restoring precisely to a specific batch;
         # first, a pytorch DataLoader always has a sampler; on the other hand, during checkpoint retraining we always replace, in `set_`, the dataloader's
-        # sampler with a `ReproducibleSampler`; otherwise, in the single-card case, the batch_sampler is replaced with a `RandomBatchSampler`
+        # sampler with a `ReproducibleSampler`; otherwise, in the single-card case, the batch_sampler is replaced with a `ReproducibleBatchSampler`
         dataloader_args = self.get_dataloader_args(dataloader)
-        if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
             sampler = dataloader_args.batch_sampler
         elif dataloader_args.sampler:
             sampler = dataloader_args.sampler
@@ -245,15 +245,14 @@
         # 3. restore the sampler state;
         dataloader_args = self.get_dataloader_args(dataloader)
-        if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
             sampler = dataloader_args.batch_sampler
-        elif isinstance(dataloader_args.sampler, ReproducibleIterator):
+        elif isinstance(dataloader_args.sampler, ReproducibleSampler):
             sampler = dataloader_args.sampler
         elif self.is_distributed():
-            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
-                               "`RandomBatchSampler` or `ReproducibleIterator`.")
+            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
         else:
-            sampler = RandomBatchSampler(
+            sampler = ReproducibleBatchSampler(
                 batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
                 batch_size=dataloader_args.batch_size,
                 drop_last=dataloader_args.drop_last
@@ -263,7 +262,7 @@
         # 4. adjust trainer_state.batch_idx_in_epoch
         # sampler here is one like RandomSampler, i.e. a sampler rather than a batch_sampler
-        if not isinstance(sampler, RandomBatchSampler):
+        if not isinstance(sampler, ReproducibleBatchSampler):
             if dataloader_args.drop_last:
                 batch_idx_in_epoch = len(
                     sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
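
A worked instance of the batch_idx_in_epoch formula above for the drop_last=True branch; the numbers are made up for illustration:

sampler_len = 103        # len(sampler): samples this sampler covers
batch_size = 10          # dataloader_args.batch_size
num_left_samples = 37    # sampler.num_left_samples after resuming

batch_idx_in_epoch = sampler_len // batch_size - num_left_samples // batch_size
# 103 // 10 = 10 full batches in total, 37 // 10 = 3 full batches still ahead,
# so training resumes at batch index 7 of the epoch
assert batch_idx_in_epoch == 7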


@@ -19,6 +19,10 @@ __all__ = [
     "UnrepeatedSortedSampler",
     "UnrepeatedSequentialSampler",
     "RandomBatchSampler",
+    "BucketedBatchSampler",
+    "ReproducibleBatchSampler",
+    "re_instantiate_sampler",
+    "conversion_between_reproducible_and_unrepeated_sampler"
 ]
@@ -28,5 +32,5 @@ from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, Unre
 from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
 from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
 from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler
-from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler
+from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler


@@ -17,6 +17,9 @@ from abc import abstractmethod
 class ReproducibleBatchSampler:
     def __init__(self, **kwargs):
         pass
+    @abstractmethod
+    def set_distributed(self, num_replicas, rank, pad=True):
+        raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.")
@@ -41,6 +44,10 @@ class ReproducibleBatchSampler:
     def set_epoch(self, epoch):
         pass
+    @property
+    def batch_idx_in_epoch(self):
+        raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
 class RandomBatchSampler(ReproducibleBatchSampler):
     # the values of these two parameters should be obtained via the driver's get_dataloader_args function;
@@ -54,6 +61,8 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         :param drop_last: whether to drop the remaining samples when they cannot form a final batch of batch_size;
         :param kwargs: used internally by fastNLP
         """
+        super().__init__()
         self.batch_sampler = batch_sampler
         self.batch_size = batch_size
         self.drop_last = drop_last
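
To close, a minimal sketch of a concrete subclass of the abstract interface introduced above, assuming only the methods these hunks declare; SequentialToyBatchSampler is hypothetical and not part of fastNLP:

class SequentialToyBatchSampler:
    def __init__(self, num_samples, batch_size):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_consumed_batches = 0

    def set_distributed(self, num_replicas, rank, pad=True):
        self.num_replicas, self.rank = num_replicas, rank  # a real subclass would shard here

    def set_epoch(self, epoch):
        self.epoch = epoch  # a real subclass would reseed its shuffling here

    @property
    def batch_idx_in_epoch(self):
        return self.num_consumed_batches

    def __iter__(self):
        # resume from however many batches were already consumed
        start = self.num_consumed_batches * self.batch_size
        for i in range(start, self.num_samples, self.batch_size):
            self.num_consumed_batches += 1
            yield list(range(i, min(i + self.batch_size, self.num_samples)))

list(SequentialToyBatchSampler(5, 2))  # [[0, 1], [2, 3], [4]]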