Revised the logic of DDPDriver.set_dist_repro_dataloader

This commit is contained in:
YWMditto 2022-04-11 18:58:55 +08:00
parent 9675eae798
commit d39fa5137e
4 changed files with 39 additions and 64 deletions

View File

@@ -49,6 +49,9 @@ class Driver(ABC):
duplicated across different gpus; 'unrepeatdist' means that the dataloader should guarantee that the data iterated on all gpus, merged
together, is exactly the original data, allowing the number of batches to differ between gpus. When the `use_dist_sampler` parameter
in the trainer kwargs is True, this value is "dist", otherwise it is None; when the `use_dist_sampler` parameter in the evaluator
kwargs is True, this value is "unrepeatdist", otherwise it is None.
Note: when dist is a ReproducibleIterator or ReproducibleBatchSampler, this function is being called by driver.load during checkpoint
resumption; when dist is a str or None, the trainer is calling this function during its initialization.
:param reproducible: if False, no special handling is required; if True, the returned dataloader must be able to save its current
iteration state so that it can be reloaded later.
:return: should return a new dataloader object whose sampler has been replaced (note that a new dataloader object must be returned here); furthermore,
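For reference, a minimal sketch of how a caller might derive the dist argument from the use_dist_sampler kwarg before calling set_dist_repro_dataloader; the resolve_dist helper below is hypothetical and only restates the contract described in the docstring above.

    from typing import Optional

    def resolve_dist(use_dist_sampler: bool, is_evaluator: bool) -> Optional[str]:
        # hypothetical helper: maps the use_dist_sampler kwarg onto the dist string
        # expected by Driver.set_dist_repro_dataloader
        if not use_dist_sampler:
            return None
        # evaluator: every sample must appear on exactly one gpu ("unrepeatdist");
        # trainer: samples may repeat so that every gpu sees the same number of batches ("dist")
        return "unrepeatdist" if is_evaluator else "dist"

    # usage sketch (driver and dataloader are assumed to exist already):
    # dataloader = driver.set_dist_repro_dataloader(dataloader, dist=resolve_dist(True, is_evaluator=False), reproducible=True)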

View File

@@ -448,31 +448,26 @@ class TorchDDPDriver(TorchDriver):
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None,
reproducible: bool = False):
# If dist is a ReproducibleBatchSampler or ReproducibleIterator, this function is being called by driver.load during checkpoint resumption;
# Note: there is no need to call dist_sampler.set_distributed here, because if the user is using TorchDDPDriver, that function was already called during Trainer initialization;
if isinstance(dist, ReproducibleBatchSampler):
dist = re_instantiate_sampler(dist)
dist.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleIterator):
# Note: there is no need to call dist_sampler.set_distributed here, because if the user is using TorchDDPDriver, that function was already called during Trainer initialization;
dist = re_instantiate_sampler(dist)
dist.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_sampler(dataloader, dist)
# If dist is a str or None, this function is being called during trainer initialization;
# trainer, evaluator
if dist is None:
if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
"control.")
else:
if isinstance(dist, ReproducibleBatchSampler):
dist = re_instantiate_sampler(dist)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleIterator):
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)
return dataloader
# trainer
elif dist == "dist":
@@ -506,7 +501,6 @@ class TorchDDPDriver(TorchDriver):
pad=True
)
return replace_sampler(dataloader, sampler)
# evaluator
elif dist == "unrepeatdist":
# todo @yh: fill in the content related to unrepeatdist;
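The resumption branch above now only re-instantiates the loaded sampler and swaps it into the dataloader, without calling set_distributed again. A rough sketch of what the replace_batch_sampler helper is assumed to do, written against plain PyTorch (the real fastNLP helper likely carries over more DataLoader attributes):

    from torch.utils.data import DataLoader

    def replace_batch_sampler_sketch(dataloader: DataLoader, batch_sampler) -> DataLoader:
        # assumption: build a brand-new DataLoader around the same dataset driven by the
        # given batch_sampler; batch_size, shuffle, sampler and drop_last must not be
        # passed together with batch_sampler.
        return DataLoader(
            dataset=dataloader.dataset,
            batch_sampler=batch_sampler,
            num_workers=dataloader.num_workers,
            collate_fn=dataloader.collate_fn,
            pin_memory=dataloader.pin_memory,
        )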

View File

@@ -132,25 +132,29 @@ class TorchSingleDriver(TorchDriver):
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None,
reproducible: bool = False):
# If dist is a ReproducibleBatchSampler or ReproducibleIterator, this function is being called by driver.load during checkpoint resumption;
if isinstance(dist, ReproducibleBatchSampler):
dist = re_instantiate_sampler(dist)
return replace_batch_sampler(dataloader, dist)
elif isinstance(dist, ReproducibleIterator):
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)
# If dist is a str or None, this function is being called during trainer initialization;
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
return replace_batch_sampler(dataloader, batch_sampler)
elif isinstance(args.sampler, ReproducibleIterator):
sampler = re_instantiate_sampler(args.sampler)
return replace_sampler(dataloader, sampler)
if reproducible:
args = self.get_dataloader_args(dataloader)
if isinstance(args.sampler, ReproducibleIterator):
sampler = re_instantiate_sampler(args.sampler)
return replace_sampler(dataloader, sampler)
else:
batch_sampler = ReproducibleBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
batch_sampler = ReproducibleBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
else:
return dataloader
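For the reproducible=True branch above, a minimal usage sketch on a plain PyTorch dataloader; the ReproducibleBatchSampler argument names are taken from the diff, while the availability of replace_batch_sampler as a top-level helper is an assumption.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from fastNLP.core.samplers import ReproducibleBatchSampler

    dataset = TensorDataset(torch.arange(10).float())
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # wrap the existing batch_sampler so the iteration position can be checkpointed
    repro_batch_sampler = ReproducibleBatchSampler(
        batch_sampler=dataloader.batch_sampler,
        batch_size=4,
        drop_last=False,
    )
    # replace_batch_sampler is expected to return a new DataLoader built around this sampler
    dataloader = replace_batch_sampler(dataloader, repro_batch_sampler)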

View File

@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
class TorchDriver(Driver):
@@ -244,47 +244,21 @@ class TorchDriver(Driver):
logger.debug("Load model.")
# 3. restore the state of the sampler;
"""
使用场景
现在sampler/batch_sampler的替换情况
1. 单卡多卡
2. 是否断点重训
3. 用户通过 dist 传入
4. 用户自己直接在外面替换dataloader的sampler或者 batchsampler
应当确定的规则
batchsampler 优先级高于 sampler
单卡
不是断点重训
用户自己
用户不自己在外面直接替换 sampler 或者 batchsampler
1. 单卡
"""
dataloader_args = self.get_dataloader_args(dataloader)
# todo: sort this out first;
# batch_sampler = dataloader_args.batch_sampler
# if not (hasattr(batch_sampler, 'load_state_dict') and callable(batch_sampler.load_state_dict)):
sampler = dataloader_args.sampler
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
# this means a ReproduceSampler has to be used here;
if self.is_distributed():
raise RuntimeError(
"It is not allowed to use single device checkpoint retraining before but ddp now.")
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleIterator):
sampler = dataloader_args.sampler
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
"`ReproducibleBatchSampler` or `ReproducibleIterator`.")
else:
sampler = ReproducibleBatchSampler(
batch_sampler=sampler,
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states['sampler_states'])
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
# 4. modify trainer_state.batch_idx_in_epoch
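To restate the new restore priority in one place, a standalone sketch that mirrors the branch above (names follow the diff; dataloader_args is assumed to expose batch_sampler, sampler, batch_size and drop_last):

    def pick_resumable_sampler(dataloader_args, is_distributed: bool):
        # 1) a ReproducibleBatchSampler already attached to the dataloader wins
        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
            return dataloader_args.batch_sampler
        # 2) otherwise a ReproducibleIterator used as the sampler
        if isinstance(dataloader_args.sampler, ReproducibleIterator):
            return dataloader_args.sampler
        # 3) under ddp, resuming without one of the above is an error
        if is_distributed:
            raise RuntimeError("checkpoint resumption requires a ReproducibleBatchSampler or ReproducibleIterator")
        # 4) single card: wrap whatever batch_sampler (or, failing that, sampler) is present
        return ReproducibleBatchSampler(
            batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
            batch_size=dataloader_args.batch_size,
            drop_last=dataloader_args.drop_last,
        )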