Removed the driver's replace_sampler and replaced it with set_dist_repro_dataloader; also modified the driver.load / driver.save functions

yh_cc 2022-04-10 00:08:19 +08:00
parent 5b54a0cd73
commit 8e4abf2aa5
15 changed files with 146 additions and 130 deletions

View File

@ -124,11 +124,7 @@ class Evaluator:
        self.dataloaders = {}
        for name, dl in dataloaders.items():  # replace with the correct sampler
-           dl = self.driver.replace_sampler(
-               dataloader=dl,
-               dist_sampler=self._dist_sampler,
-               reproducible=False
-           )
+           dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False)
            self.dataloaders[name] = dl

        self.progress_bar = kwargs.get('progress_bar', 'auto')

View File

@ -250,11 +250,8 @@ class Trainer(TrainerEventTrigger):
        self.dataloader = self.train_dataloader
        self.driver.set_deterministic_dataloader(self.dataloader)

-       self.dataloader = self.driver.replace_sampler(
-           dataloader=self.train_dataloader,
-           dist_sampler=_dist_sampler,
-           reproducible=self.callback_manager.has_trainer_chechpoint
-       )
+       self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
+                                                               reproducible=self.callback_manager.has_trainer_chechpoint)

        self.set_grad_to_none = kwargs.get("set_grad_to_none", True)

        self.on_after_trainer_initialized(self.driver)
@ -578,22 +575,6 @@ class Trainer(TrainerEventTrigger):
        else:
            states["val_filter_state"] = None

-       # 4. the sampler's state; we support resume training, i.e. resuming exactly at a specific batch:
-       # first, a pytorch DataLoader always has a sampler; second, for checkpoint resumption we always replace the dataloader's
-       # sampler with a `ReproducibleIterator` inside `replace_sampler`, or, in the single-card case, replace its batch_sampler with a `ReproducibleBatchSampler`
-       dataloader_args = self.driver.get_dataloader_args(self.dataloader)
-       if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
-           sampler = dataloader_args.batch_sampler
-       elif dataloader_args.sampler:
-           sampler = dataloader_args.sampler
-       else:
-           raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
-       if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
-           states['sampler_states'] = sampler.state_dict()
-       else:
-           raise RuntimeError(
-               'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')

        if isinstance(folder, str):
            folder = Path(folder)
@ -601,9 +582,9 @@ class Trainer(TrainerEventTrigger):
            if not callable(model_save_fn):
                raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
            rank_zero_call(model_save_fn)(folder)
-           self.driver.save(folder=folder, states=states, should_save_model=False, **kwargs)
+           self.driver.save(folder=folder, dataloader=self.dataloader, states=states, should_save_model=False, **kwargs)
        else:
-           self.driver.save(folder=folder, states=states,
+           self.driver.save(folder=folder, dataloader=self.dataloader, states=states,
                             only_state_dict=only_state_dict, should_save_model=True, **kwargs)
        self.driver.barrier()
@ -616,9 +597,6 @@ class Trainer(TrainerEventTrigger):
        is saved; in that case the dataloader's sampler will not necessarily have been replaced with our ReproducibleIterator;

        note that we currently do not support checkpoint resumption from single-card to multi-card training;

-       TODO: note that we currently do not support checkpoint resumption across RandomSampler, BucketedSampler or SortedSampler;
-       therefore, if a user needs to use a BucketedSampler, they have to initialize it themselves before the Trainer and replace the
-       sampler in the original Dataloader, regardless of whether this is the first checkpoint-resumed run or a later reload for retraining;

        :param folder: the path where the checkpoint-resumption states were saved;
        :param resume_training: whether to resume from the last batch, or only from the most recent epoch; note that if resume_training=True, we
@ -627,33 +605,23 @@ class Trainer(TrainerEventTrigger):
        self.driver.barrier()
        if isinstance(folder, str):
            folder = Path(folder)

+       dataloader = self.dataloader
+       if not resume_training:
+           dataloader = None

        if model_load_fn is not None:
            if not callable(model_load_fn):
-               raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
+               raise ValueError("Parameter `model_save_fn` should be `Callable`.")
            rank_zero_call(model_load_fn)(folder)
-           states = self.driver.load(folder=folder, should_load_model=False, **kwargs)
+           states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
        else:
-           states = self.driver.load(folder=folder, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
+           states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)

        if not resume_training:
            return

-       # 1. restore the sampler's state;
-       dataloader_args = self.driver.get_dataloader_args(self.dataloader)
-       sampler = dataloader_args.sampler
-       if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
-           # in this case we need to fall back to a ReproducibleBatchSampler
-           if self.driver.is_distributed():
-               raise RuntimeError("It is not allowed to use single device checkpoint retraining before but ddp now.")
-           sampler = ReproducibleBatchSampler(
-               batch_sampler=sampler,
-               batch_size=dataloader_args.batch_size,
-               drop_last=dataloader_args.drop_last
-           )
-       sampler.load_state_dict(states['sampler_states'])
-       self.driver.replace_sampler(self.dataloader, sampler)
+       self.dataloader = states.pop('dataloader')

        # 2. validate filter state
        if self.evaluator is not None:
@ -668,22 +636,16 @@ class Trainer(TrainerEventTrigger):
        # 4. adjust trainer_state.batch_idx_in_epoch
        # sampler here is a sampler like RandomSampler, not a batch_sampler
-       if not isinstance(sampler, ReproducibleBatchSampler):
-           if dataloader_args.drop_last:
-               self.trainer_state.batch_idx_in_epoch = len(sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
-           else:
-               self.trainer_state.batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
-                                                       (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
-       # sampler is a batch_sampler
-       else:
-           self.trainer_state.batch_idx_in_epoch = sampler.batch_idx_in_epoch
+       # the principle here: 'number of batches still to be produced' + 'batch_idx_in_epoch' = 'total number of batches without checkpoint resumption';
+       # since 'number of batches still to be produced' is determined by how many samples are left, the only way to make the equation hold is to adjust 'batch_idx_in_epoch'
+       self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')

        # 5. restore the states of all callbacks;
        self.on_load_checkpoint(states["callback_states"])

        self.driver.barrier()

-   """ These four functions make it convenient for users to customize their own batch_step_fn (i.e. to replace the step function inside train_batch_loop) """
+   """ These four functions make it convenient for users to customize their own batch_step_fn (i.e. to replace the batch_step_fn function inside train_batch_loop) """

    def train_step(self, batch):
        with self.driver.auto_cast():

View File

@ -2,7 +2,7 @@ import os
import signal
import sys
from typing import Any, Sequence, List, Optional, Callable, Dict, Union
-from abc import ABC
+from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from io import BytesIO
@ -14,7 +14,6 @@ __all__ = [
from fastNLP.core.utils import nullcontext

-# todo 航总 please check which methods need @abstractmethod

class Driver(ABC):
    r"""
    Base class for initializing a `Driver`; every custom `driver` must inherit from this class;
@ -32,29 +31,33 @@ class Driver(ABC):
        # self._consensus_file: Optional[Union[str, Path]] = None
        self._pids: Optional[List[int]] = None

+   @abstractmethod
    def setup(self):
        r"""
        This function initializes the training environment, e.g. it moves the model to the corresponding device;
        it is more involved for multi-card drivers, which may need to set up inter-process communication, environment variables, and other required values;
        """
-   def replace_sampler(self, dataloader, dist_sampler: Optional[str], reproducible: bool = False):
+   def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible: bool = False):
        r"""
-       Some special situations require replacing the dataloader's sampler, and this function of each driver provides that capability;
-       for example, in multi-card training we need to replace the sampler with a distributed sampler, and if the user adds a
-       checkpoint-resumption callback to the Trainer, we need to replace the sampler with a reproducible sampler;
+       Given the input dataloader, return a dataloader that supports distributed (dist) and reproducible iteration;

-       :param dataloader: the original dataloader passed into the trainer;
-       :param dist_sampler: should be a string, one of [None, "dist", "unrepeatdist"], specifying which sampler to use;
-           currently this parameter is tailored to distributed training: when the trainer kwarg `use_dist_sampler` is True the value is "dist", otherwise None;
-           when the evaluator kwarg `use_dist_sampler` is True the value is "unrepeatdist", otherwise None;
-       :param reproducible: used in the `Trainer` to specify whether to switch to a checkpoint-resumption sampler (multi-card) or batch_sampler (single-card); for a single-card Driver
-           with this parameter True (i.e. checkpoint resumption is currently in progress), we use our `ReproducibleBatchSampler` to replace the dataloader's original batch_sampler;
-           for a multi-card Driver, we replace the dataloader's original sampler with a `RandomSampler`;
-       :return: should return a new dataloader object whose sampler has been replaced (note: a new dataloader object must be returned here);
+       :param dataloader: set up the corresponding distributed and reproducible version based on this dataloader;
+       :param dist: should be a string, one of [None, "dist", "unrepeatdist"]; None means this dataloader does not need to be switched
+           into a distributed state; 'dist' means the dataloader should guarantee that each gpu gets the same number of batches, allowing a small number of
+           samples to be repeated across different gpus; 'unrepeatdist' means the data iterated across all gpus, merged together, should be exactly the original
+           data, allowing the number of batches to differ across gpus; when the trainer kwarg `use_dist_sampler` is True this value is "dist",
+           otherwise None; when the evaluator kwarg `use_dist_sampler` is True this value is "unrepeatdist", otherwise None;
+       :param reproducible: if False, nothing special needs to be done; if True, the returned dataloader must be able to save its current iteration state,
+           so that it can be loaded back later;
+       :return: should return a new dataloader object whose sampler has been replaced (note: a new dataloader object must be returned here); in addition,
+           if the incoming dataloader contains a ReproducibleIterator or ReproducibleBatchSampler, a newly initialized one must be placed into the returned
+           dataloader; if dist is None and reproducible is False, the original object may be returned directly;
        """
-       raise NotImplementedError("Each specific driver should implemented its own `replace_sampler` function.")
+       if dist is None and reproducible is False:
+           return dataloader
+       raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `set_dist_repro_dataloader` "
+                                 f"function.")
    def set_deterministic_dataloader(self, dataloader):
        r"""
@ -68,7 +71,7 @@ class Driver(ABC):
        :param cur_epoch_idx: the index of the current epoch;
        """

+   @abstractmethod
    def train_step(self, batch):
        """
        Implement the forward pass of training by calling the model's own `train_step` or `forward` method;
@ -103,7 +106,7 @@ class Driver(ABC):
        therefore, if the user's evaluator mode is validate but the passed-in model implements test_step instead of validate_step,
        we should remind the user of this behavior;
        """
-       raise NotImplementedError("Each specific driver should implemented its own `predict_step` function.")
+       raise NotImplementedError("Each specific driver should implemented its own `check_evaluator_mode` function.")

    @property
    def model(self):
@ -234,6 +237,7 @@ class Driver(ABC):
""" """
self.optimizers = optimizers self.optimizers = optimizers
@abstractmethod
def backward(self, loss): def backward(self, loss):
""" """
实现深度学习中的反向传播过程 实现深度学习中的反向传播过程
@ -242,12 +246,14 @@ class Driver(ABC):
""" """
raise NotImplementedError("Each specific driver should implemented its own `backward` function.") raise NotImplementedError("Each specific driver should implemented its own `backward` function.")
@abstractmethod
def step(self): def step(self):
r""" r"""
实现深度学习中的参数的优化更新过程应当直接通过优化器 optimizers 来更新参数 实现深度学习中的参数的优化更新过程应当直接通过优化器 optimizers 来更新参数
""" """
raise NotImplementedError("Each specific driver should implemented its own `step` function.") raise NotImplementedError("Each specific driver should implemented its own `step` function.")
@abstractmethod
def zero_grad(self, set_to_none: bool = False): def zero_grad(self, set_to_none: bool = False):
r""" r"""
实现深度学习中的梯度的置零操作应当直接通过优化器 optimizers 来将梯度置零 实现深度学习中的梯度的置零操作应当直接通过优化器 optimizers 来将梯度置零
@ -286,6 +292,7 @@ class Driver(ABC):
    def auto_cast(self, auto_cast):
        self._auto_cast = auto_cast

+   @abstractmethod
    def save_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = True, **kwargs):
        r"""
        Function for saving the model; note that the function `save` is the one used for checkpoint resumption;
@ -296,6 +303,7 @@
        """
        raise NotImplementedError("Each specific driver should implemented its own `save_model` function.")

+   @abstractmethod
    def load_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = False, **kwargs):
        r"""
        Function for loading the model; it loads the model from filepath and assigns it to the current model;
@ -307,7 +315,8 @@
        """
        raise NotImplementedError("Each specific driver should implemented its own `load_model` function.")

-   def save(self, folder, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+   @abstractmethod
+   def save(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
        r"""
        Save function for checkpoint resumption; it is responsible for saving the optimizers' and fp16 state_dict, and for saving the model when should_save_model is True;
@ -317,12 +326,14 @@ class Driver(ABC):
        :param states: a dict passed in by the trainer that already contains the states of all other objects that need to be saved for checkpoint
            resumption; the Driver only needs to save this object as-is and does not need to understand it; driver.load() must return these states,
            and the value returned by load() should be consistent with the value passed in here;
+       :param dataloader: the dataloader currently in use; its state needs to be saved so that training can later resume from the current iteration position;
        :param only_state_dict: whether to save only the model parameters; ignored when should_save_model is False;
        :param should_save_model: whether the model should be saved; if False, the Driver is not responsible for saving the model;
        """
        raise NotImplementedError("Each specific driver should implemented its own `save` function.")

-   def load(self, folder: Union[str, Path], only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict:
+   @abstractmethod
+   def load(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict:
        r"""
        Load function for checkpoint resumption; note that this function is responsible for reading the data and restoring the optimizers'
        and fp16 state_dict, the model (depending on should_load_model), and everything else saved in Driver.save(); it then returns a
        state dict to the trainer whose content is the states that Driver.save() received;
@ -331,11 +342,22 @@ class Driver(ABC):
        :param folder: read the FASTNLP_CHECKPOINT_FILENAME file under this folder, and the FASTNLP_MODEL_FILENAME file
            if should_load_model is True;
+       :param dataloader: the dataloader currently given; it needs to be set up properly according to the dataloader state that was saved; if this value
+           is None, the 'dataloader' and 'batch_idx_in_epoch' entries do not need to be returned;
        :param only_state_dict: ignored when the saved should_save_model was False; if True, the saved content is the weights; if False,
            the whole model was saved, but the saved model's weights are still loaded through the current Driver's model rather than replacing the current model with the saved one;
        :param should_load_model: whether the model should be loaded; if False, the Driver is not responsible for loading the model; if this parameter
            is True but no corresponding model state is found in the saved states, an error is raised;
        :return: must return the states content that was passed into the save function;
+           'dataloader': the passed-in dataloader, with its state set sensibly according to the saved state; the returned object may be the same object as the passed-in dataloader;
+               an error is raised if the number of data samples at save time does not match that of the currently passed-in dataloader;
+           'batch_idx_in_epoch': an int indicating which batch the current epoch has reached; note that this value cannot simply be read from the saved
+               data, because the batch_size may have changed between the two runs; the number should satisfy the following equation:
+               'number of batches the returned dataloader will still produce' + 'batch_idx_in_epoch' = 'total number of batches without checkpoint resumption';
+               since 'number of batches the returned dataloader will still produce' cannot be changed once batch_size and drop_last are given,
+               the equation can only be satisfied by adjusting batch_idx_in_epoch; a simple rule for computing it:
+               when drop_last is True, it equals floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
+               when drop_last is False, it equals ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)
        """
        raise NotImplementedError("Each specific driver should implemented its own `load` function.")
@ -352,6 +374,7 @@ class Driver(ABC):
""" """
raise NotImplementedError("Each specific driver should implemented its own `tensor_to_numeric` function.") raise NotImplementedError("Each specific driver should implemented its own `tensor_to_numeric` function.")
@abstractmethod
def set_model_mode(self, mode: str): def set_model_mode(self, mode: str):
r""" r"""
设置模型为 `train` / `eval` 的模式目的是为切换模型训练和推理会关闭dropout等模式 设置模型为 `train` / `eval` 的模式目的是为切换模型训练和推理会关闭dropout等模式
@ -378,6 +401,7 @@ class Driver(ABC):
我们需要先将模型移到 cpu 又再移到 gpu 因此不适宜在该函数内部调用 `unwrap_model`而是将 model 作为该函数的参数 我们需要先将模型移到 cpu 又再移到 gpu 因此不适宜在该函数内部调用 `unwrap_model`而是将 model 作为该函数的参数
""" """
@abstractmethod
def move_data_to_device(self, batch): def move_data_to_device(self, batch):
r""" r"""
将数据迁移到指定的机器上batch 可能是 list 也可能 dict 或其嵌套结构 将数据迁移到指定的机器上batch 可能是 list 也可能 dict 或其嵌套结构
@ -399,17 +423,6 @@ class Driver(ABC):
仅在多分布式训练场景中有使用 仅在多分布式训练场景中有使用
""" """
@staticmethod
def get_dataloader_args(dataloader):
"""
用于从 dataloader 中抽取一些属性的值返回的dataclass中必须包含以下的key
sampler, batch_sampler, batch_size, drop_last
:param dataloader:
:return: 返回一个 dataclass其实例属性应当包括以上的各个属性并且其名字也应当与这些属性相同从而方便 trainer 或者其它对象调用
"""
raise NotImplementedError("Each specific driver should implemented its own `get_dataloader_args` function.")
def is_distributed(self) -> bool: def is_distributed(self) -> bool:
""" """
当前的 driver 实例是否是分布式的 当前的 driver 实例是否是分布式的

View File

@ -70,7 +70,8 @@ class JittorMPIDriver(JittorDriver):
    def test_step(self, batch):
        return self._test_step(batch)

-   def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
+   def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
+                                 reproducible: bool = False, sampler_or_batch_sampler=None):
        pass

    def backward(self, loss):

View File

@ -99,14 +99,15 @@ class JittorSingleDriver(JittorDriver):
    def is_distributed(self):
        return False

-   def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
+   def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
+                                 reproducible: bool = False, sampler_or_batch_sampler=None):
        # the reproducible functionality has not been implemented yet
-       if isinstance(dist_sampler, ReproducibleBatchSampler):
+       if isinstance(dist, ReproducibleBatchSampler):
            raise NotImplementedError
            dataloader.batch_sampler = dist_sample
-       if isinstance(dist_sampler, ReproducibleIterator):
+       if isinstance(dist, ReproducibleIterator):
            raise NotImplementedError
-           dataloader.batch_sampler.sampler = dist_sampler
+           dataloader.batch_sampler.sampler = dist

        if reproducible:
            raise NotImplementedError

View File

@ -316,13 +316,14 @@ class PaddleFleetDriver(PaddleDriver):
    def test_step(self, batch):
        return self._test_step(batch)

-   def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
+   def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
+                                 reproducible: bool = False, sampler_or_batch_sampler=None):
        # IterableDataset is not supported yet
        assert dataloader.dataset_kind != _DatasetKind.ITER, \
            "FastNLP does not support `IteratorDataset` now."
-       if isinstance(dist_sampler, ReproducibleIterator):
-           dataloader.batch_sampler.sampler = dist_sampler
+       if isinstance(dist, ReproducibleIterator):
+           dataloader.batch_sampler.sampler = dist
            return dataloader

        # paddle's BatchSampler and DataLoader have no shuffle member; it can only be inferred from the sampler
@ -334,14 +335,14 @@ class PaddleFleetDriver(PaddleDriver):
            shuffle = dataloader.batch_sampler.shuffle

        # trainer, evaluator
-       if dist_sampler is None:
+       if dist is None:
            if reproducible:
                raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
                                   "control.")
            else:
                return dataloader
        # trainer
-       elif dist_sampler == "dist":
+       elif dist == "dist":
            # if the user's trainer.use_dist_sampler is True, whether checkpoint resumption is enabled does not affect the behavior here
            if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
                dataloader.batch_sampler.sampler.set_distributed(
@ -364,7 +365,7 @@ class PaddleFleetDriver(PaddleDriver):
                dataloader.batch_sampler.sampler = sampler
                return dataloader
        # evaluator
-       elif dist_sampler == "unrepeatdist":
+       elif dist == "unrepeatdist":
            sampler = UnrepeatedDistributedSampler(
                dataset=dataloader.dataset,
                shuffle=shuffle,

View File

@ -133,15 +133,16 @@ class PaddleSingleDriver(PaddleDriver):
""" """
return paddle_move_data_to_device(batch, "gpu:0") return paddle_move_data_to_device(batch, "gpu:0")
def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False): def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
reproducible: bool = False, sampler_or_batch_sampler=None):
# 暂时不支持IteratorDataset # 暂时不支持IteratorDataset
assert dataloader.dataset_kind != _DatasetKind.ITER, \ assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now." "FastNLP does not support `IteratorDataset` now."
if isinstance(dist_sampler, ReproducibleBatchSampler): if isinstance(dist, ReproducibleBatchSampler):
dataloader.batch_sampler = dist_sampler dataloader.batch_sampler = dist
return dataloader return dataloader
if isinstance(dist_sampler, ReproducibleIterator): if isinstance(dist, ReproducibleIterator):
dataloader.batch_sampler.sampler = dist_sampler dataloader.batch_sampler.sampler = dist
return dataloader return dataloader
if reproducible: if reproducible:

View File

@ -445,21 +445,22 @@ class TorchDDPDriver(TorchDriver):
        # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
        return self._test_step(batch)

-   def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
-       if isinstance(dist_sampler, ReproducibleIterator):
+   def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
+                                 reproducible: bool = False, sampler_or_batch_sampler=None):
+       if isinstance(dist, ReproducibleIterator):
            # note that dist_sampler.set_distributed does not need to be called here, because if the user is using TorchDDPDriver it was already called during Trainer initialization;
-           dist_sampler = re_instantiate_sampler(dist_sampler)
-           return replace_sampler(dataloader, dist_sampler)
+           dist = re_instantiate_sampler(dist)
+           return replace_sampler(dataloader, dist)

        # trainer, evaluator
-       if dist_sampler is None:
+       if dist is None:
            if reproducible:
                raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
                                   "control.")
            else:
                return dataloader
        # trainer
-       elif dist_sampler == "dist":
+       elif dist == "dist":
            args = self.get_dataloader_args(dataloader)
            # if the user's trainer.use_dist_sampler is True, whether checkpoint resumption is enabled does not affect the behavior here
            if isinstance(args.sampler, ReproducibleIterator):
@ -485,7 +486,7 @@ class TorchDDPDriver(TorchDriver):
            return replace_sampler(dataloader, sampler)
        # evaluator
-       elif dist_sampler == "unrepeatdist":
+       elif dist == "unrepeatdist":
            args = self.get_dataloader_args(dataloader)
            sampler = UnrepeatedDistributedSampler(
                dataset=args.dataset,

View File

@ -130,12 +130,12 @@ class TorchSingleDriver(TorchDriver):
        else:
            return self._test_step(batch)

-   def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
-                       reproducible: bool = False):
-       if isinstance(dist_sampler, ReproducibleBatchSampler):
-           return replace_batch_sampler(dataloader, dist_sampler)
-       elif isinstance(dist_sampler, ReproducibleIterator):
-           return replace_sampler(dataloader, dist_sampler)
+   def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
+                                 reproducible: bool = False, sampler_or_batch_sampler=None):
+       if isinstance(dist, ReproducibleBatchSampler):
+           return replace_batch_sampler(dataloader, dist)
+       elif isinstance(dist, ReproducibleIterator):
+           return replace_sampler(dataloader, dist)

        if reproducible:
            args = self.get_dataloader_args(dataloader)

View File

@ -50,6 +50,14 @@ class ReproducibleIterator:
class RandomSampler(ReproducibleIterator):

    def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
+       """
+       :param dataset: a data container that implements __len__;
+       :param shuffle: whether to shuffle the order on each iteration;
+       :param seed: the random seed;
+       :param kwargs: not meant for users; used internally by fastNLP
+       """
        self.dataset = dataset
        self.shuffle = shuffle
@ -208,6 +216,15 @@ class RandomSampler(ReproducibleIterator):
class ReproducibleBatchSampler:
    # the values of these two parameters should be obtained through the driver's get_dataloader_args function;
    def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
+       """
+       A wrapper that makes the iteration state of a batch_sampler object recoverable;
+
+       :param batch_sampler: an iterable that yields numbers or lists of numbers; ReproducibleBatchSampler first iterates over this object once,
+           stashes the yielded indices, and then, when used, emits lists of indices of size batch_size;
+       :param batch_size: the size of each batch;
+       :param drop_last: whether to drop the trailing samples that cannot form a full batch of batch_size;
+       :param kwargs: used internally by fastNLP
+       """
        self.batch_sampler = batch_sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
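A minimal usage sketch, mirroring how the removed Trainer code wrapped a plain sampler (the import path and the exact behaviour of state_dict()/load_state_dict() are assumptions inferred from this diff rather than verified against the library):

    from torch.utils.data import DataLoader, SequentialSampler
    from fastNLP.core.samplers import ReproducibleBatchSampler  # assumed import location

    dataset = list(range(10))
    batch_sampler = ReproducibleBatchSampler(batch_sampler=SequentialSampler(dataset),
                                             batch_size=4, drop_last=False)
    dl = DataLoader(dataset, batch_sampler=batch_sampler)

    state = None
    for i, batch in enumerate(dl):
        if i == 0:                              # pretend the run is interrupted after one batch
            state = batch_sampler.state_dict()  # capture the current iteration position
            break

    # in a later run: rebuild the wrapper and resume from the saved position
    restored = ReproducibleBatchSampler(batch_sampler=SequentialSampler(dataset),
                                        batch_size=4, drop_last=False)
    restored.load_state_dict(state)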

View File

@ -15,7 +15,7 @@ def remove_local_rank_in_argv():
""" """
index = -1 index = -1
for i, v in enumerate(sys.argv): for i, v in enumerate(sys.argv):
if v.startswith('--rank='): if v.startswith('--local_rank='):
os.environ['LOCAL_RANK'] = v.split('=')[1] os.environ['LOCAL_RANK'] = v.split('=')[1]
index = i index = i
break break
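For context, a small sketch of what the corrected prefix matches (the argv values are hypothetical; the index recorded above is presumably used to strip the argument from sys.argv afterwards, which the sketch does inline):

    import os
    import sys

    # hypothetical command line injected by a distributed launcher
    sys.argv = ["train.py", "--local_rank=3", "--my_arg=1"]

    for i, v in enumerate(sys.argv):
        if v.startswith('--local_rank='):   # the old check used '--rank=' and therefore never matched
            os.environ['LOCAL_RANK'] = v.split('=')[1]
            sys.argv.pop(i)                 # drop it so user argument parsers are unaffected
            break

    assert os.environ['LOCAL_RANK'] == '3'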

View File

@ -3,4 +3,4 @@ prettytable>=0.7.2
requests
regex!=2019.12.17
rich==11.2.0
-# fsspec[http]>=2021.05.0, !=2021.06.0
+packaging

View File

@ -1,12 +1,9 @@
import pytest
-import sys
import os
import numpy as np

-from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

set_env_on_import_paddle()
-set_env("paddle")
import paddle
import paddle.distributed as dist
from paddle.io import DataLoader
@ -54,6 +51,7 @@ def test_move_data_to_device():
    dist.barrier()

@magic_argv_env_context
def test_is_distributed():
    print(os.getenv("CUDA_VISIBLE_DEVICES"))
@ -64,6 +62,7 @@ def test_is_distributed():
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
+           output_from_new_proc='all'
        )
        driver.set_optimizers(paddle_opt)
        # distinguish between launch and the child-process setup case
@ -79,6 +78,7 @@ def test_is_distributed():
synchronize_safe_rm("log") synchronize_safe_rm("log")
dist.barrier() dist.barrier()
@magic_argv_env_context @magic_argv_env_context
def test_get_no_sync_context(): def test_get_no_sync_context():
""" """
@ -105,6 +105,7 @@ def test_get_no_sync_context():
synchronize_safe_rm("log") synchronize_safe_rm("log")
dist.barrier() dist.barrier()
@magic_argv_env_context @magic_argv_env_context
def test_is_global_zero(): def test_is_global_zero():
try: try:
@ -128,6 +129,8 @@ def test_is_global_zero():
synchronize_safe_rm("log") synchronize_safe_rm("log")
dist.barrier() dist.barrier()
@magic_argv_env_context @magic_argv_env_context
def test_unwrap_model(): def test_unwrap_model():
try: try:
@ -204,7 +207,7 @@ def test_replace_sampler(dist_sampler, reproducible):
        else:
            driver.setup()
        dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
-       driver.replace_sampler(dataloader, dist_sampler, reproducible)
+       driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
    finally:
        synchronize_safe_rm("log")
    dist.barrier()
@ -243,7 +246,7 @@ class SingleMachineMultiGPUTrainingTestCase:
            parallel_device=gpus,
        )
        driver.set_optimizers(paddle_opt)
-       dataloader = driver.replace_sampler(dataloader)
+       dataloader = driver.set_dist_repro_dataloader(dataloader, )
        driver.setup()
        # check model_device
        self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}")

View File

@ -164,4 +164,4 @@ class TestSingleDeviceFunction:
""" """
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
res = self.driver.replace_sampler(dataloader, dist_sampler, reproducible) res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)

View File

@ -33,11 +33,15 @@ def check_replace_sampler(driver):
    # dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / ReproducibleBatchSampler
    # reproducible is True or False
+   # need to check that both the returned sampler and the returned dataloader are different objects

    assert driver.is_distributed() is False, "This test only for non distributed sampler."
    ds = SequenceDataSet(10)
    dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True)
-   dl1 = driver.replace_sampler(dataloader, dist_sampler='dist', reproducible=True)
+   dl1 = driver.set_dist_repro_dataloader(dataloader, dist='dist', reproducible=True)
+   assert not (dl1.sampler is dataloader.sampler), "The sampler should not the same one."
+   assert not (dl1 is dataloader), "The dataloader should not the same one."

    # iterate two batches
    already_seen_idx = set()
@ -68,6 +72,22 @@ def check_replace_sampler(driver):
            assert b not in already_seen_idx
            assert b in left_idxes

+   # need to check that switching to unrepeatdist works: (1) no extra padding; (2) no overlap between cards
+   ds = SequenceDataSet(11)
+   dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True)
+   dl1 = driver.set_dist_repro_dataloader(dataloader, dist='unrepeatdist', reproducible=True)
+   world_size = 3
+   indices = []
+   for i in range(world_size):
+       dl1.sampler.set_distributed(num_replicas=world_size, rank=i)
+       for idx, batch in dl1:
+           indices.extend(batch)
+   assert len(indices)==len(ds)  # there should be no repetition at all
+   assert len(set(indices))==len(indices)  # all indices should be distinct