Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-11-29 18:59:01 +08:00)
Removed the driver's replace_sampler in favor of set_dist_repro_dataloader; also changed the driver.load/driver.save functions
parent 5b54a0cd73 · commit 8e4abf2aa5
@@ -124,11 +124,7 @@ class Evaluator:
         self.dataloaders = {}
         for name, dl in dataloaders.items():  # replace with the correct sampler
-            dl = self.driver.replace_sampler(
-                dataloader=dl,
-                dist_sampler=self._dist_sampler,
-                reproducible=False
-            )
+            dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False)
             self.dataloaders[name] = dl

         self.progress_bar = kwargs.get('progress_bar', 'auto')
@@ -250,11 +250,8 @@ class Trainer(TrainerEventTrigger):
         self.dataloader = self.train_dataloader
         self.driver.set_deterministic_dataloader(self.dataloader)

-        self.dataloader = self.driver.replace_sampler(
-            dataloader=self.train_dataloader,
-            dist_sampler=_dist_sampler,
-            reproducible=self.callback_manager.has_trainer_chechpoint
-        )
+        self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
+                                                                reproducible=self.callback_manager.has_trainer_chechpoint)

         self.set_grad_to_none = kwargs.get("set_grad_to_none", True)
         self.on_after_trainer_initialized(self.driver)
@@ -578,22 +575,6 @@ class Trainer(TrainerEventTrigger):
         else:
             states["val_filter_state"] = None

-        # 4. the sampler state; we support resume training, i.e. recovering precisely to a specific batch;
-        # first, a pytorch DataLoader always has a sampler; moreover, for checkpoint retraining we always replace the
-        # dataloader's sampler with a `ReproducibleIterator` in `replace_sampler`, or, on a single device, replace the batch_sampler with a `ReproducibleBatchSampler`;
-        dataloader_args = self.driver.get_dataloader_args(self.dataloader)
-        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
-            sampler = dataloader_args.batch_sampler
-        elif dataloader_args.sampler:
-            sampler = dataloader_args.sampler
-        else:
-            raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
-
-        if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
-            states['sampler_states'] = sampler.state_dict()
-        else:
-            raise RuntimeError(
-                'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
         if isinstance(folder, str):
             folder = Path(folder)

@@ -601,9 +582,9 @@ class Trainer(TrainerEventTrigger):
             if not callable(model_save_fn):
                 raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
             rank_zero_call(model_save_fn)(folder)
-            self.driver.save(folder=folder, states=states, should_save_model=False, **kwargs)
+            self.driver.save(folder=folder, dataloader=self.dataloader, states=states, should_save_model=False, **kwargs)
         else:
-            self.driver.save(folder=folder, states=states,
+            self.driver.save(folder=folder, dataloader=self.dataloader, states=states,
                              only_state_dict=only_state_dict, should_save_model=True, **kwargs)

         self.driver.barrier()
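Taken together with the Evaluator and Trainer hunks above, the checkpoint contract now threads the live dataloader through the driver in both directions. A minimal sketch of the resulting round-trip (a hedged illustration only, not code from this commit; `trainer`, `folder` and `states` are assumed to exist as in the surrounding Trainer code):

    # Sketch of the save/load round-trip as reshaped by this commit (assumed names: trainer, folder, states).
    trainer.driver.save(folder=folder, dataloader=trainer.dataloader, states=states,
                        should_save_model=True)           # the driver now persists the sampler state itself

    # later, possibly in a fresh process:
    states = trainer.driver.load(folder=folder, dataloader=trainer.dataloader,
                                 should_load_model=True)  # pass dataloader=None when not resuming training
    trainer.dataloader = states.pop('dataloader')         # restored to the saved iteration position
    batch_idx = states.pop('batch_idx_in_epoch')          # see the Driver.load docstring further down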
@@ -616,9 +597,6 @@ class Trainer(TrainerEventTrigger):
        is saved; in that case the dataloader's sampler will not necessarily have been replaced by our ReproducibleIterator;

        Note that we currently do not support checkpoint retraining from a single device to multiple devices;
-        TODO: note that we currently do not support checkpoint retraining across RandomSampler, BucketedSampler or SortedSampler;
-        so if users need a BucketedSampler themselves, they have to initialize it before the Trainer and replace the sampler
-        in the original Dataloader, both for the first checkpoint retraining and for every later reload;

        :param folder: the directory holding the saved checkpoint-retraining states;
        :param resume_training: whether to resume from the last batch, or only from the latest epoch; note that if resume_training=True, then we
@@ -627,33 +605,23 @@ class Trainer(TrainerEventTrigger):
         self.driver.barrier()
         if isinstance(folder, str):
             folder = Path(folder)

+        dataloader = self.dataloader
+        if not resume_training:
+            dataloader = None

         if model_load_fn is not None:
             if not callable(model_load_fn):
-                raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
+                raise ValueError("Parameter `model_save_fn` should be `Callable`.")
             rank_zero_call(model_load_fn)(folder)
-            states = self.driver.load(folder=folder, should_load_model=False, **kwargs)
+            states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
         else:
-            states = self.driver.load(folder=folder, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
+            states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)

         if not resume_training:
             return

-        # 1. restore the sampler state;
-        dataloader_args = self.driver.get_dataloader_args(self.dataloader)
-
-        sampler = dataloader_args.sampler
-        if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
-            # this means we have to fall back to our ReproducibleBatchSampler here
-            if self.driver.is_distributed():
-                raise RuntimeError("It is not allowed to use single device checkpoint retraining before but ddp now.")
-            sampler = ReproducibleBatchSampler(
-                batch_sampler=sampler,
-                batch_size=dataloader_args.batch_size,
-                drop_last=dataloader_args.drop_last
-            )
-        sampler.load_state_dict(states['sampler_states'])
-
-        self.driver.replace_sampler(self.dataloader, sampler)
+        self.dataloader = states.pop('dataloader')

         # 2. validate filter state;
         if self.evaluator is not None:
@@ -668,22 +636,16 @@ class Trainer(TrainerEventTrigger):

         # 4. adjust trainer_state.batch_idx_in_epoch
         # sampler here is a sampler like RandomSampler, not a batch_sampler;
-        if not isinstance(sampler, ReproducibleBatchSampler):
-            if dataloader_args.drop_last:
-                self.trainer_state.batch_idx_in_epoch = len(sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
-            else:
-                self.trainer_state.batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
-                    (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
-        # sampler is a batch_sampler;
-        else:
-            self.trainer_state.batch_idx_in_epoch = sampler.batch_idx_in_epoch
+        # The rule here is that 'the number of batches still to be produced' + 'batch_idx_in_epoch' must equal 'the total number of batches of an uninterrupted epoch'.
+        # Since 'the number of batches still to be produced' is determined by how many samples are left, only 'batch_idx_in_epoch' can be adjusted to satisfy the equation.
+        self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')

         # 5. restore the state of every callback;
         self.on_load_checkpoint(states["callback_states"])

         self.driver.barrier()

-    """ These four functions make it convenient for users to customize their own batch_step_fn (used to replace the step function in train_batch_loop) """
+    """ These four functions make it convenient for users to customize their own batch_step_fn (used to replace the batch_step_fn function in train_batch_loop) """

     def train_step(self, batch):
         with self.driver.auto_cast():
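The invariant in the new comment can be checked with concrete numbers (a worked example, not code from this commit): with 10 samples, batch_size=2 and drop_last=True, an uninterrupted epoch yields 5 batches; if 4 samples are still unseen, the restored dataloader will produce 2 more batches, so batch_idx_in_epoch must be set to 3.

    # Worked example of the invariant: 'batches still to come' + batch_idx_in_epoch == 'total batches'.
    num_samples, batch_size, num_left = 10, 2, 4
    total_batches = num_samples // batch_size             # 5 with drop_last=True
    batches_left = num_left // batch_size                 # 2
    batch_idx_in_epoch = total_batches - batches_left     # 3
    assert batches_left + batch_idx_in_epoch == total_batches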
@@ -2,7 +2,7 @@ import os
 import signal
 import sys
 from typing import Any, Sequence, List, Optional, Callable, Dict, Union
-from abc import ABC
+from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
 from io import BytesIO
@@ -14,7 +14,6 @@ __all__ = [
 from fastNLP.core.utils import nullcontext


-# todo 航总: check which of these methods need @abstractmethod;
 class Driver(ABC):
     r"""
     Base class for initializing a `Driver`; all customized drivers must inherit from this class;
@@ -32,29 +31,33 @@ class Driver(ABC):
         # self._consensus_file: Optional[Union[str, Path]] = None
         self._pids: Optional[List[int]] = None

+    @abstractmethod
     def setup(self):
         r"""
         This function initializes the training environment, e.g. moving the model onto the corresponding device;
         for a multi-device driver it is more involved, e.g. it may need to open the communication environment between processes and set environment variables and other required values;
         """

-    def replace_sampler(self, dataloader, dist_sampler: Optional[str], reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible: bool = False):
         r"""
-        Some special situations require replacing the dataloader's sampler, and this function in every driver provides that
-        capability; for example, in multi-device training we need to replace the sampler with a distributed sampler; and if the
-        user added a checkpoint-retraining callback to the Trainer, we need to replace the sampler with a reproducible sampler;
-
-        :param dataloader: the original dataloader passed in by the trainer;
-        :param dist_sampler: should be a string, one of [None, "dist", "unrepeatdist"]; it specifies which sampler to use;
-            currently this parameter serves distributed training: when the trainer kwarg `use_dist_sampler` is True this value is "dist", otherwise None;
-            when the evaluator kwarg `use_dist_sampler` is True this value is "unrepeatdist", otherwise None;
-        :param reproducible: used in the `Trainer` to request the checkpoint-retraining sampler (multi-device) or batch_sampler (single-device); for a single-device Driver,
-            if this parameter is True, i.e. checkpoint retraining is in progress, we replace the dataloader's original batch_sampler with our `ReproducibleBatchSampler`;
-            for a multi-device Driver, we replace the dataloader's original sampler with `RandomSampler`;
-        :return: should return a new dataloader object whose sampler has been replaced (note that a new dataloader object must be returned here);
+        Given the input dataloader, return a dataloader that supports distributed execution and is reproducible.
+
+        :param dataloader: the dataloader for which to set up the distributed and reproducible counterpart
+        :param dist: should be a string, one of [None, "dist", "unrepeatdist"]; when None, the current dataloader does not need
+            to be switched to a distributed state; when 'dist', the dataloader must guarantee that every gpu gets the same
+            number of batches, allowing a small number of samples to be repeated across gpus; when 'unrepeatdist', the
+            dataloader must guarantee that the data iterated on all gpus, merged together, is exactly the original data,
+            allowing the number of batches to differ between gpus. When the trainer kwarg `use_dist_sampler` is True this value
+            is "dist", otherwise None; when the evaluator kwarg `use_dist_sampler` is True this value is "unrepeatdist", otherwise None;
+        :param reproducible: if False, nothing further needs to be considered; if True, the returned dataloader must be able to
+            save its current iteration state so that it can be loaded again.
+        :return: should return a new dataloader object whose sampler has been replaced (note that a new dataloader object must be returned here); additionally,
+            if the passed-in dataloader contains a ReproducibleIterator or a ReproducibleBatchSampler, a newly initialized one must be put into the
+            returned dataloader. If dist is None and reproducible is False, the original object may be returned directly.
         """
-        raise NotImplementedError("Each specific driver should implemented its own `replace_sampler` function.")
+        if dist is None and reproducible is False:
+            return dataloader
+        raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `set_dist_repro_dataloader` "
+                                  f"function.")

     def set_deterministic_dataloader(self, dataloader):
         r"""
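To make the new contract concrete, a single-device driver consistent with this docstring would roughly dispatch as follows. This is a sketch only, not part of the commit; it mirrors the TorchSingleDriver hunk further down, where `replace_batch_sampler` and `replace_sampler` are the torch-side helper functions:

    # Sketch of a conforming implementation, written against the documented contract.
    def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible=False):
        if dist is None and reproducible is False:
            return dataloader                               # base-class fast path added by this commit
        if isinstance(dist, ReproducibleBatchSampler):
            return replace_batch_sampler(dataloader, dist)  # caller supplied a concrete batch sampler
        if isinstance(dist, ReproducibleIterator):
            return replace_sampler(dataloader, dist)        # caller supplied a concrete sampler
        if reproducible:
            # wrap the existing batch_sampler so its iteration state can be saved and restored
            batch_sampler = ReproducibleBatchSampler(batch_sampler=dataloader.batch_sampler,
                                                     batch_size=dataloader.batch_size,
                                                     drop_last=dataloader.drop_last)
            return replace_batch_sampler(dataloader, batch_sampler)
        return dataloader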
@@ -68,7 +71,7 @@ class Driver(ABC):

         :param cur_epoch_idx: the index of the current epoch;
         """
-
+    @abstractmethod
     def train_step(self, batch):
         """
         Implements the forward pass of training by calling the model's own `train_step` or `forward` method;
@@ -103,7 +106,7 @@ class Driver(ABC):
         so if the user's evaluator mode is validate, but the model passed in does not implement a validate_step function and
         instead implements a test_step function, we should warn the user about this behavior;
         """
-        raise NotImplementedError("Each specific driver should implemented its own `predict_step` function.")
+        raise NotImplementedError("Each specific driver should implemented its own `check_evaluator_mode` function.")

     @property
     def model(self):
@@ -234,6 +237,7 @@ class Driver(ABC):
         """
         self.optimizers = optimizers

+    @abstractmethod
     def backward(self, loss):
         """
         Implements the backward pass of deep learning;
@@ -242,12 +246,14 @@ class Driver(ABC):
         """
         raise NotImplementedError("Each specific driver should implemented its own `backward` function.")

+    @abstractmethod
     def step(self):
         r"""
         Implements the parameter-update step of deep learning; parameters should be updated directly through the optimizers;
         """
         raise NotImplementedError("Each specific driver should implemented its own `step` function.")

+    @abstractmethod
     def zero_grad(self, set_to_none: bool = False):
         r"""
         Implements zeroing the gradients; gradients should be zeroed directly through the optimizers;
|
|||||||
def auto_cast(self, auto_cast):
|
def auto_cast(self, auto_cast):
|
||||||
self._auto_cast = auto_cast
|
self._auto_cast = auto_cast
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def save_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = True, **kwargs):
|
def save_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = True, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
保存模型的函数;注意函数 `save` 是用来进行断点重训的函数;
|
保存模型的函数;注意函数 `save` 是用来进行断点重训的函数;
|
||||||
@ -296,6 +303,7 @@ class Driver(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError("Each specific driver should implemented its own `save_model` function.")
|
raise NotImplementedError("Each specific driver should implemented its own `save_model` function.")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def load_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = False, **kwargs):
|
def load_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = False, **kwargs):
|
||||||
r"""
|
r"""
|
||||||
加载模型的函数;将 filepath 中的模型加载并赋值给当前 model 。
|
加载模型的函数;将 filepath 中的模型加载并赋值给当前 model 。
|
||||||
@@ -307,7 +315,8 @@ class Driver(ABC):
         """
         raise NotImplementedError("Each specific driver should implemented its own `load_model` function.")

-    def save(self, folder, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+    @abstractmethod
+    def save(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
         r"""
         The checkpoint-retraining save function; it is responsible for saving the optimizers' and fp16 state_dict, as well as the model itself (if should_save_model is True)

|
|||||||
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存
|
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存
|
||||||
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load() 返回的值与这里的
|
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load() 返回的值与这里的
|
||||||
传入的值保持一致。
|
传入的值保持一致。
|
||||||
|
:param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。
|
||||||
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。
|
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。
|
||||||
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。
|
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Each specific driver should implemented its own `save` function.")
|
raise NotImplementedError("Each specific driver should implemented its own `save` function.")
|
||||||
|
|
||||||
def load(self, folder: Union[str, Path], only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict:
|
@abstractmethod
|
||||||
|
def load(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict:
|
||||||
r"""
|
r"""
|
||||||
断点重训的加载函数,注意该函数会负责读取数据,并且恢复 optimizers , fp16 的 state_dict 和 模型(根据 should_load_model )和;
|
断点重训的加载函数,注意该函数会负责读取数据,并且恢复 optimizers , fp16 的 state_dict 和 模型(根据 should_load_model )和;
|
||||||
其它在 Driver.save() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save() 接受到的 states )。
|
其它在 Driver.save() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save() 接受到的 states )。
|
||||||
@@ -331,11 +342,22 @@ class Driver(ABC):

         :param folder: read the FASTNLP_CHECKPOINT_FILENAME file under this folder, as well as FASTNLP_MODEL_FILENAME
             (if should_load_model is True).
+        :param dataloader: the currently given dataloader; it must be set up sensibly from the saved dataloader state. If this
+            value is None, the 'dataloader' and 'batch_idx_in_epoch' entries need not be returned.
         :param only_state_dict: ignored when should_save_model is False. If True, the saved content is the weights; if
             False, the saved content is the model, but the saved model's weights are still loaded through the current Driver's model rather than replacing the current model with the saved one.
         :param should_load_model: whether the model should be loaded; if False, the Driver is not responsible for loading the
             model. If this parameter is True but no corresponding model state is found in the saved states, an error is raised.
-        :return: must return the states content that the save function received;
+        :return: must return the states content that the save function received
+            'dataloader': set up to a sensible state from the passed-in dataloader together with the saved state; the returned
+                object may be the same one as the passed-in dataloader. An error is raised when the saved number of data
+                samples does not match the currently passed-in one.
+            'batch_idx_in_epoch': an int telling which batch the current epoch had reached. Note that this value cannot simply
+                be read from the saved data, because the batch_size may have changed between the two runs. The number must
+                satisfy the equation 'the number of batches the returned dataloader will still produce' + 'batch_idx_in_epoch'
+                = 'the total number of batches of an uninterrupted epoch'. Since the first quantity cannot change once
+                batch_size and drop_last are given, the equation can only be satisfied by adjusting batch_idx_in_epoch. A simple rule:
+                when drop_last is True, it equals floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
+                when drop_last is False, it equals ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size).
         """
         raise NotImplementedError("Each specific driver should implemented its own `load` function.")

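The floor/ceil rules at the end of this docstring translate directly into code; a driver's `load` could compute the value with a helper along these lines (a sketch under the docstring's assumptions; `sample_in_this_rank` and `num_left_samples` are the quantities named above):

    import math

    def restored_batch_idx_in_epoch(sample_in_this_rank, num_left_samples, batch_size, drop_last):
        # 'batches the returned dataloader will still produce' + batch_idx_in_epoch
        # must equal the total batch count of an uninterrupted epoch.
        if drop_last:
            return sample_in_this_rank // batch_size - num_left_samples // batch_size
        return math.ceil(sample_in_this_rank / batch_size) - math.ceil(num_left_samples / batch_size)

    assert restored_batch_idx_in_epoch(10, 4, 2, True) == 3   # 5 batches in total, 2 still to come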
@@ -352,6 +374,7 @@ class Driver(ABC):
         """
         raise NotImplementedError("Each specific driver should implemented its own `tensor_to_numeric` function.")

+    @abstractmethod
     def set_model_mode(self, mode: str):
         r"""
         Sets the model to `train` / `eval` mode; used to switch between training and inference (which disables dropout etc.);
@@ -378,6 +401,7 @@ class Driver(ABC):
         we need to move the model to cpu first and then back to gpu, so it is not appropriate to call `unwrap_model` inside this function; instead, the model is taken as a parameter;
         """

+    @abstractmethod
     def move_data_to_device(self, batch):
         r"""
         Moves the data to the specified device; batch may be a list or a dict, or a nested structure of these.
@@ -399,17 +423,6 @@ class Driver(ABC):
         Only used in distributed training scenarios.
         """

-    @staticmethod
-    def get_dataloader_args(dataloader):
-        """
-        Extracts the values of some attributes from the dataloader; the returned dataclass must contain the following keys:
-        sampler, batch_sampler, batch_size, drop_last;
-
-        :param dataloader:
-        :return: a dataclass whose instance attributes include the attributes above, under the same names, so that the trainer or other objects can use them conveniently;
-        """
-        raise NotImplementedError("Each specific driver should implemented its own `get_dataloader_args` function.")
-
     def is_distributed(self) -> bool:
         """
         Whether the current driver instance is distributed;
@@ -70,7 +70,8 @@ class JittorMPIDriver(JittorDriver):
     def test_step(self, batch):
         return self._test_step(batch)

-    def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
         pass

     def backward(self, loss):
@@ -99,14 +99,15 @@ class JittorSingleDriver(JittorDriver):
     def is_distributed(self):
         return False

-    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
         # the reproducible-related functionality has not been implemented yet
-        if isinstance(dist_sampler, ReproducibleBatchSampler):
+        if isinstance(dist, ReproducibleBatchSampler):
             raise NotImplementedError
             dataloader.batch_sampler = dist_sample
-        if isinstance(dist_sampler, ReproducibleIterator):
+        if isinstance(dist, ReproducibleIterator):
             raise NotImplementedError
-            dataloader.batch_sampler.sampler = dist_sampler
+            dataloader.batch_sampler.sampler = dist

         if reproducible:
             raise NotImplementedError
@@ -316,13 +316,14 @@ class PaddleFleetDriver(PaddleDriver):
     def test_step(self, batch):
         return self._test_step(batch)

-    def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):

         # IterableDataset is not supported yet
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."
-        if isinstance(dist_sampler, ReproducibleIterator):
-            dataloader.batch_sampler.sampler = dist_sampler
+        if isinstance(dist, ReproducibleIterator):
+            dataloader.batch_sampler.sampler = dist
             return dataloader

         # paddle's BatchSampler and DataLoader have no shuffle member; it can only be inferred from the sampler
@@ -334,14 +335,14 @@ class PaddleFleetDriver(PaddleDriver):
             shuffle = dataloader.batch_sampler.shuffle

         # trainer, evaluator
-        if dist_sampler is None:
+        if dist is None:
             if reproducible:
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
                                    "control.")
             else:
                 return dataloader
         # trainer
-        elif dist_sampler == "dist":
+        elif dist == "dist":
             # if the user's trainer.use_dist_sampler is True, whether checkpoint retraining is in progress does not affect the behavior here;
             if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
                 dataloader.batch_sampler.sampler.set_distributed(
@@ -364,7 +365,7 @@ class PaddleFleetDriver(PaddleDriver):
                 dataloader.batch_sampler.sampler = sampler
                 return dataloader
         # evaluator
-        elif dist_sampler == "unrepeatdist":
+        elif dist == "unrepeatdist":
             sampler = UnrepeatedDistributedSampler(
                 dataset=dataloader.dataset,
                 shuffle=shuffle,
@@ -133,15 +133,16 @@ class PaddleSingleDriver(PaddleDriver):
         """
         return paddle_move_data_to_device(batch, "gpu:0")

-    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
         # IteratorDataset is not supported yet
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."
-        if isinstance(dist_sampler, ReproducibleBatchSampler):
-            dataloader.batch_sampler = dist_sampler
+        if isinstance(dist, ReproducibleBatchSampler):
+            dataloader.batch_sampler = dist
             return dataloader
-        if isinstance(dist_sampler, ReproducibleIterator):
-            dataloader.batch_sampler.sampler = dist_sampler
+        if isinstance(dist, ReproducibleIterator):
+            dataloader.batch_sampler.sampler = dist
             return dataloader

         if reproducible:
@@ -445,21 +445,22 @@ class TorchDDPDriver(TorchDriver):
         # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
         return self._test_step(batch)

-    def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
-        if isinstance(dist_sampler, ReproducibleIterator):
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
+        if isinstance(dist, ReproducibleIterator):
             # note that there is no need to call dist_sampler.set_distributed here; if the user is using TorchDDPDriver, that function was already called when the Trainer was initialized;
-            dist_sampler = re_instantiate_sampler(dist_sampler)
-            return replace_sampler(dataloader, dist_sampler)
+            dist = re_instantiate_sampler(dist)
+            return replace_sampler(dataloader, dist)

         # trainer, evaluator
-        if dist_sampler is None:
+        if dist is None:
             if reproducible:
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
                                    "control.")
             else:
                 return dataloader
         # trainer
-        elif dist_sampler == "dist":
+        elif dist == "dist":
             args = self.get_dataloader_args(dataloader)
             # if the user's trainer.use_dist_sampler is True, whether checkpoint retraining is in progress does not affect the behavior here;
             if isinstance(args.sampler, ReproducibleIterator):
|
|||||||
return replace_sampler(dataloader, sampler)
|
return replace_sampler(dataloader, sampler)
|
||||||
|
|
||||||
# evaluator
|
# evaluator
|
||||||
elif dist_sampler == "unrepeatdist":
|
elif dist == "unrepeatdist":
|
||||||
args = self.get_dataloader_args(dataloader)
|
args = self.get_dataloader_args(dataloader)
|
||||||
sampler = UnrepeatedDistributedSampler(
|
sampler = UnrepeatedDistributedSampler(
|
||||||
dataset=args.dataset,
|
dataset=args.dataset,
|
||||||
|
@@ -130,12 +130,12 @@ class TorchSingleDriver(TorchDriver):
         else:
             return self._test_step(batch)

-    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
-                        reproducible: bool = False):
-        if isinstance(dist_sampler, ReproducibleBatchSampler):
-            return replace_batch_sampler(dataloader, dist_sampler)
-        elif isinstance(dist_sampler, ReproducibleIterator):
-            return replace_sampler(dataloader, dist_sampler)
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
+        if isinstance(dist, ReproducibleBatchSampler):
+            return replace_batch_sampler(dataloader, dist)
+        elif isinstance(dist, ReproducibleIterator):
+            return replace_sampler(dataloader, dist)

         if reproducible:
             args = self.get_dataloader_args(dataloader)
|
@ -50,6 +50,14 @@ class ReproducibleIterator:
|
|||||||
|
|
||||||
class RandomSampler(ReproducibleIterator):
|
class RandomSampler(ReproducibleIterator):
|
||||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
|
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
:param dataset: 实现了 __len__ 方法的数据容器
|
||||||
|
:param shuffle: 是否在每次 iterate 的时候打乱顺序。
|
||||||
|
:param seed: 随机数种子。
|
||||||
|
:param kwargs: 用户不需要使用,fastNLP 内部使用
|
||||||
|
"""
|
||||||
|
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
@@ -208,6 +216,15 @@ class RandomSampler(ReproducibleIterator):
 class ReproducibleBatchSampler:
     # the values of these two parameters should be obtained through the driver's get_dataloader_args function;
     def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
+        """
+        A wrapper that makes the state of a batch_sampler object recoverable.
+
+        :param batch_sampler: an iterable that yields integers or lists of integers. ReproducibleBatchSampler first iterates
+            over this object once and caches the yielded indices; when used, it emits lists of indices in batches of size batch_size.
+        :param batch_size: the size of each batch.
+        :param drop_last: whether to drop the last batch if it cannot be filled up to batch_size samples.
+        :param kwargs: used internally by fastNLP.
+        """
         self.batch_sampler = batch_sampler
         self.batch_size = batch_size
         self.drop_last = drop_last
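This wrapper is what `Trainer.load` above used to fall back to on a single device; a usage sketch follows (hedged: it assumes the `state_dict()`/`load_state_dict()` pair that the Trainer code relies on, and uses torch's stock samplers as the wrapped object):

    # Usage sketch, not a test from this commit.
    from torch.utils.data import BatchSampler, SequentialSampler

    inner = BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=False)
    sampler = ReproducibleBatchSampler(batch_sampler=inner, batch_size=2, drop_last=False)

    it = iter(sampler)
    first = next(it)                    # consume one batch
    snapshot = sampler.state_dict()     # persist the mid-epoch position

    restored = ReproducibleBatchSampler(batch_sampler=inner, batch_size=2, drop_last=False)
    restored.load_state_dict(snapshot)  # resumes after the first batch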
@@ -15,7 +15,7 @@ def remove_local_rank_in_argv():
     """
     index = -1
     for i, v in enumerate(sys.argv):
-        if v.startswith('--rank='):
+        if v.startswith('--local_rank='):
            os.environ['LOCAL_RANK'] = v.split('=')[1]
            index = i
            break
@@ -3,4 +3,4 @@ prettytable>=0.7.2
 requests
 regex!=2019.12.17
 rich==11.2.0
-# fsspec[http]>=2021.05.0, !=2021.06.0
+packaging
@@ -1,12 +1,9 @@
 import pytest
-import sys
 import os
 import numpy as np
-from fastNLP.envs.set_backend import set_env
 from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

 set_env_on_import_paddle()
-set_env("paddle")
 import paddle
 import paddle.distributed as dist
 from paddle.io import DataLoader
@@ -54,6 +51,7 @@ def test_move_data_to_device():

     dist.barrier()

+
 @magic_argv_env_context
 def test_is_distributed():
     print(os.getenv("CUDA_VISIBLE_DEVICES"))
@@ -64,6 +62,7 @@ def test_is_distributed():
     driver = PaddleFleetDriver(
         model=paddle_model,
         parallel_device=[0,1],
+        output_from_new_proc='all'
     )
     driver.set_optimizers(paddle_opt)
     # distinguishing between launch and child-process setup
@@ -79,6 +78,7 @@ def test_is_distributed():
         synchronize_safe_rm("log")
         dist.barrier()

+
 @magic_argv_env_context
 def test_get_no_sync_context():
     """
@@ -105,6 +105,7 @@ def test_get_no_sync_context():
         synchronize_safe_rm("log")
         dist.barrier()

+
 @magic_argv_env_context
 def test_is_global_zero():
     try:
@@ -128,6 +129,8 @@ def test_is_global_zero():
         synchronize_safe_rm("log")
         dist.barrier()

+
+
 @magic_argv_env_context
 def test_unwrap_model():
     try:
@@ -204,7 +207,7 @@ def test_replace_sampler(dist_sampler, reproducible):
     else:
         driver.setup()
         dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
-        driver.replace_sampler(dataloader, dist_sampler, reproducible)
+        driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
     finally:
         synchronize_safe_rm("log")
         dist.barrier()
@@ -243,7 +246,7 @@ class SingleMachineMultiGPUTrainingTestCase:
             parallel_device=gpus,
         )
         driver.set_optimizers(paddle_opt)
-        dataloader = driver.replace_sampler(dataloader)
+        dataloader = driver.set_dist_repro_dataloader(dataloader, )
         driver.setup()
         # check model_device
         self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}")
@@ -164,4 +164,4 @@ class TestSingleDeviceFunction:
         """
         dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)

-        res = self.driver.replace_sampler(dataloader, dist_sampler, reproducible)
+        res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
@@ -33,11 +33,15 @@ def check_replace_sampler(driver):
     # dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / ReproducibleBatchSampler
     # reproducible is True or False

+    # need to check that both the returned sampler and the returned dataloader are different objects
     assert driver.is_distributed() is False, "This test only for non distributed sampler."
     ds = SequenceDataSet(10)
     dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True)

-    dl1 = driver.replace_sampler(dataloader, dist_sampler='dist', reproducible=True)
+    dl1 = driver.set_dist_repro_dataloader(dataloader, dist='dist', reproducible=True)
+
+    assert not (dl1.sampler is dataloader.sampler), "The sampler should not the same one."
+    assert not (dl1 is dataloader), "The dataloader should not the same one."

     # iterate two batches
     already_seen_idx = set()
@@ -68,6 +72,22 @@ def check_replace_sampler(driver):
             assert b not in already_seen_idx
             assert b in left_idxes

+    # check that switching to unrepeatdist is sound: (1) no extra padding is added; (2) the different cards do not overlap
+    ds = SequenceDataSet(11)
+    dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True)
+    dl1 = driver.set_dist_repro_dataloader(dataloader, dist='unrepeatdist', reproducible=True)
+    world_size = 3
+    indices = []
+    for i in range(world_size):
+        dl1.sampler.set_distributed(num_replicas=world_size, rank=i)
+        for idx, batch in dl1:
+            indices.extend(batch)
+    assert len(indices)==len(ds)  # there should be no duplicates at all
+    assert len(set(indices))==len(indices)  # all indices should be distinct