Follow up on the checkpoint-retraining (resume) settings

x54-729 2022-04-10 14:59:45 +00:00
parent da849564d6
commit 9678c559c9
4 changed files with 271 additions and 112 deletions

View File

@@ -10,6 +10,7 @@ from .utils import (
_MODE_PARAMETER,
get_device_from_visible,
reset_seed,
replace_sampler
)
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@@ -19,8 +20,13 @@ from fastNLP.core.utils import (
paddle_move_data_to_device,
is_in_paddle_dist,
)
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.samplers import (
ReproducibleIterator,
RandomSampler,
UnrepeatedDistributedSampler,
re_instantiate_sampler,
)
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
from fastNLP.core.log import logger
if _NEED_IMPORT_PADDLE:
@@ -314,23 +320,15 @@ class PaddleFleetDriver(PaddleDriver):
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
reproducible: bool = False, sampler_or_batch_sampler=None):
# IterableDataset is not supported for now
assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now."
if isinstance(dist, ReproducibleIterator):
dataloader.batch_sampler.sampler = dist
return dataloader
# paddle's BatchSampler and DataLoader have no `shuffle` member, so shuffle can only be inferred from the sampler;
# however, the subclass DistributedBatchSampler does have a `shuffle` member,
# so we use type() for a strict check
if type(dataloader.batch_sampler) == BatchSampler:
shuffle = isinstance(dataloader.batch_sampler.sampler, RandomSampler)
else:
shuffle = dataloader.batch_sampler.shuffle
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)
# trainer, evaluator
# the user initialized the distributed setup themselves; do nothing
if dist is None:
if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
@@ -339,40 +337,40 @@ class PaddleFleetDriver(PaddleDriver):
return dataloader
# trainer
elif dist == "dist":
args = self.get_dataloader_args(dataloader)
# if the user's trainer.use_dist_sampler is True, then whether checkpoint retraining is enabled does not affect the behavior here
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
dataloader.batch_sampler.sampler.set_distributed(
if isinstance(args.sampler, ReproducibleIterator):
sampler = re_instantiate_sampler(args.sampler)
sampler.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return dataloader
return replace_sampler(dataloader, sampler)
else:
sampler = RandomSampler(
dataset=dataloader.dataset,
shuffle=shuffle,
seed=int(os.environ.get("FASTNLP_SEED", 0))
dataset=args.dataset,
shuffle=args.shuffle,
seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
)
sampler.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
dataloader.batch_sampler.sampler = sampler
return dataloader
return replace_sampler(dataloader, sampler)
# evaluator
elif dist == "unrepeatdist":
args = self.get_dataloader_args(dataloader)
sampler = UnrepeatedDistributedSampler(
dataset=dataloader.dataset,
shuffle=shuffle,
seed=int(os.environ.get("FASTNLP_SEED", 0))
dataset=args.dataset,
shuffle=args.shuffle,
)
sampler.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank
)
dataloader.batch_sampler.sampler = sampler
return dataloader
return replace_sampler(dataloader, sampler)
else:
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")

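Note on `re_instantiate_sampler`: the new code paths above deliberately avoid mutating the sampler already attached to a dataloader; they build a fresh instance and swap it in via `replace_sampler`. A minimal sketch of the contract such a helper has to satisfy (re-create a sampler of the same class from its instance attributes; the actual fastNLP implementation may differ in detail):

def re_instantiate_sampler(sampler):
    # rebuild a sampler of the same class from its non-private instance
    # attributes, leaving the original instance untouched
    init_args = {k: v for k, v in vars(sampler).items() if not k.startswith("_")}
    return type(sampler)(**init_args)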
View File

@@ -1,21 +1,31 @@
import os
import random
from typing import Union, Optional, Callable, Dict
from typing import Union, Optional, Dict
from pathlib import Path
from functools import partial
from dataclasses import dataclass
import numpy as np
from .utils import _build_fp16_env
from .utils import _build_fp16_env, optimizer_state_to_device
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler
if _NEED_IMPORT_PADDLE:
import paddle
from paddle.io import DataLoader, IterableDataset
from paddle.io import (
DataLoader,
IterableDataset,
Dataset,
Sampler,
BatchSampler,
RandomSampler,
)
from paddle.optimizer import Optimizer
_reduces = {
@@ -69,6 +79,8 @@ class PaddleDriver(Driver):
# TODO for now we forbid the dataloader's dataset from being of the IterableDataset kind;
if isinstance(dataloader.dataset, IterableDataset):
raise TypeError("`IterableDataset` is not allowed.")
if dataloader.batch_sampler is None and dataloader.batch_size is None:
raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.")
else:
if not isinstance(dataloader, Dict):
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
@@ -79,6 +91,9 @@ class PaddleDriver(Driver):
f"type, not {type(each_dataloader)}.")
if isinstance(each_dataloader.dataset, IterableDataset):
raise TypeError("`IterableDataset` is not allowed.")
if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None:
raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of "
f"`batch_sampler` and `batch_size` should be set.")
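The check added above encodes the batching contract of paddle's DataLoader: batches come either from `batch_size` (paddle then builds a BatchSampler internally) or from an explicit `batch_sampler`, so at least one of the two must be set. A hypothetical illustration of the two accepted configurations:

import paddle
from paddle.io import BatchSampler, DataLoader, SequenceSampler, TensorDataset

ds = TensorDataset([paddle.zeros([10, 1])])
# 1) implicit batching: DataLoader builds a BatchSampler from batch_size
dl1 = DataLoader(ds, batch_size=4)
# 2) explicit batching: pass a BatchSampler (batch_size etc. keep their defaults)
dl2 = DataLoader(ds, batch_sampler=BatchSampler(sampler=SequenceSampler(ds), batch_size=4))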
@staticmethod
def _check_optimizer_legality(optimizers):
@@ -153,45 +168,53 @@ class PaddleDriver(Driver):
getattr(self.model, mode)()
@rank_zero_call
def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable]=None, **kwargs):
def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
r"""
Save the model. Note that the function `save` is the one used for checkpoint retraining;
if `model_save_fn` is a callable, we will run that function directly;
:param filepath: location of the file to save to (must include the file name);
:param only_state_dict: whether to save only the model's `state_dict`; note that this parameter only takes effect when `model_save_fn` is None;
:param model_save_fn: a user-provided function that replaces this function's own saving logic; if it is not None, we will call model_save_fn(path);
:param only_state_dict: whether to save only the model's `state_dict`;
:param kwargs:
:return:
"""
if model_save_fn is not None:
model_save_fn(filepath)
model = self.unwrap_model()
if only_state_dict:
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
paddle.save(states, filepath)
else:
model = self.unwrap_model()
if only_state_dict:
paddle.save(model.state_dict(), filepath)
# paddle requires extra arguments when saving a whole model
input_spec = kwargs.get("input_spec", None)
if input_spec is None:
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
if self.model_device is not None:
if not self.is_distributed():
self.move_model_to_device(model, "cpu")
paddle.jit.save(model, filepath, input_spec)
if not self.is_distributed():
self.move_model_to_device(model, self.model_device)
else:
input_spec = kwargs.get("input_spec", None)
if input_spec is None:
raise Exception("To save the whole Paddle Layer, parameter 'input_spec' is needed.")
paddle.jit.save(model, filepath, input_spec)
@staticmethod
@rank_zero_call
def load_model(filepath: str, load_dict: bool = True):
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
r"""
Load the model. Note that the function `load` is the one used for checkpoint retraining;
:param filepath: location of the object to be loaded (must include the file name);
:param load_dict: whether to load a state_dict; defaults to True. If the user set only_state_dict to False in save_model,
i.e. saved the whole model, this parameter must also be False
:return: the result of loading the given file
:param kwargs:
:return:
"""
if load_dict:
return paddle.load(filepath)
model = self.unwrap_model()
if only_state_dict:
model.load_dict(paddle.load(filepath))
else:
return paddle.jit.load(filepath)
model.load_dict(paddle.jit.load(filepath).state_dict())
@rank_zero_call
def save(self, folder, states: Dict):
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
r"""
The saving function for checkpoint retraining; it is responsible for saving the model's and the optimizers' state_dict.
Note that a driver should be stateless: its interface functions should return the same results no matter when they are called; checkpoint retraining therefore does not need to save the driver itself.
@@ -203,48 +226,110 @@ class PaddleDriver(Driver):
:param states: a dict passed in by the trainer that already contains the states of the other objects that must be saved for
checkpoint retraining; the Driver only needs to save this object and does not need to understand it; in driver.load(),
the states must be handed back, and the value returned by load() must match the one passed in here;
:param dataloader: the dataloader currently in use; its state must be saved so that training can later resume from the current iteration position;
:param only_state_dict: whether to save only the model's parameters; this parameter has no effect when should_save_model is False;
:param should_save_model: whether the model should be saved; if False, the Driver is not responsible for saving the model;
:return:
"""
# 1. Save the model's state;
model = self.unwrap_model()
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
# for the single-card driver, we should not currently account for the efficiency loss caused by users running single-card mode inside a DDP environment;
states["model_state_dict"] = model_state_dict
# the dataloader argument passed in is the trainer's dataloader attribute: we never change any of the driver's dataloaders;
# instead we change the dataloader's state by changing trainer.dataloader, adapting it to the training or evaluation environment;
# 2. Save the optimizers' state;
# 1. The sampler's state, since we support resume training, i.e. restoring exactly to a specific batch;
# after a paddle DataLoader is initialized, its batch_sampler may be None, or it may be the batch_sampler set by the user
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif dataloader_args.sampler:
sampler = dataloader_args.sampler
else:
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
states['sampler_states'] = sampler.state_dict()
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
# 2. Save the model's state;
if should_save_model:
model = self.unwrap_model()
if only_state_dict:
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
paddle.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME))
logger.debug("Save model state dict")
else:
input_spec = kwargs.get("input_spec", None)
if input_spec is None:
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
paddle.jit.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME), input_spec)
logger.debug("Save model")
# 3. Save the optimizers' state;
optimizers_state_dict = {}
for i in range(len(self.optimizers)):
optimizer: Optimizer = self.optimizers[i]
optimizer_state = optimizer.state_dict()
optimizer_state = {name: param.cpu().detach().clone() for name, param in optimizer_state.items()}
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu")
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy测试是不需要的
logger.debug("Save optimizer state dict")
states["optimizers_state_dict"] = optimizers_state_dict
paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
paddle.save(states, folder)
def load(self, filepath) -> Dict:
r"""
The loading function for checkpoint retraining. Note that this function is responsible for reading the data and restoring the model's and optimizers' state_dict, etc.;
the driver instance should first load the model's and optimizers' state_dict in this function, and then return a state dict to the trainer;
the accepted and returned values of the save and load functions should therefore correspond to each other;
this function needs to run on all ranks.
:param filepath: name of the file that stores the checkpoint-retraining state;
:return: must return the states content that was passed into the save function;
"""
states = paddle.load(filepath)
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
states = paddle.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
# 1. Load the optimizers' state;
optimizers_state_dict = states["optimizers_state_dict"]
for i in range(len(self.optimizers)):
optimizer: paddle.optimizer.Optimizer = self.optimizers[i]
optimizer: Optimizer = self.optimizers[i]
optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"])
logger.debug("Load optimizer state dict.")
# 2. Load the model's state;
model = self.unwrap_model()
model.load_dict(states["model_state_dict"])
if should_load_model:
model = self.unwrap_model()
if only_state_dict:
res = paddle.load(folder.joinpath(FASTNLP_MODEL_FILENAME))
model.load_dict(res)
logger.debug("Load model state dict.")
else:
model.load_dict(paddle.jit.load(folder.joinpath(FASTNLP_MODEL_FILENAME)).state_dict())
logger.debug("Load model.")
# 3. Restore the sampler's state;
dataloader_args = self.get_dataloader_args(dataloader)
sampler = dataloader_args.sampler
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
# this means we need to use a ReproducibleBatchSampler to handle it here
if self.is_distributed():
raise RuntimeError(
"It is not allowed to use single device checkpoint retraining before but ddp now.")
sampler = ReproducibleBatchSampler(
batch_sampler=sampler,
batch_size=dataloader_args.batch_sampler.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states['sampler_states'])
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
# 4. Update trainer_state.batch_idx_in_epoch
# the sampler is a RandomSampler-like sampler, not a batch_sampler
if not isinstance(sampler, ReproducibleBatchSampler):
if dataloader_args.drop_last:
batch_idx_in_epoch = len(
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
else:
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
# the sampler is a batch_sampler
else:
batch_idx_in_epoch = sampler.batch_idx_in_epoch
states["batch_idx_in_epoch"] = batch_idx_in_epoch
self.barrier()
return states
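The `batch_idx_in_epoch` arithmetic above recovers how many batches of the interrupted epoch were already consumed, from the number of samples the sampler still has left. A worked check with hypothetical numbers (10 samples, batch_size 4, 6 samples left, i.e. one full batch already consumed):

num_samples, batch_size, num_left = 10, 4, 6  # hypothetical values

# drop_last=True: count full batches only
print(num_samples // batch_size - num_left // batch_size)  # 2 - 1 = 1

# drop_last=False: ceil-divide both counts
def ceil_div(a, b):
    return (a + b - 1) // b
print(ceil_div(num_samples, batch_size) - ceil_div(num_left, batch_size))  # 3 - 2 = 1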
def get_evaluate_context(self):
@@ -313,3 +398,53 @@ class PaddleDriver(Driver):
"""
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
dataloader.batch_sampler.set_epoch(cur_epoch_idx)
@staticmethod
def get_dataloader_args(dataloader: "DataLoader"):
"""
获取 dataloader shuffle drop_last 属性
"""
@dataclass
class Res:
dataset: Optional[Dataset] = None
batch_sampler: Optional[BatchSampler] = None
sampler: Optional[Sampler] = None
batch_size: Optional[int] = None
shuffle: Optional[bool] = None
drop_last: Optional[bool] = None
res = Res()
# a paddle DataLoader is guaranteed to have a dataset attribute;
res.dataset = dataloader.dataset
if dataloader.batch_sampler is not None:
res.batch_sampler = dataloader.batch_sampler
if hasattr(dataloader.batch_sampler, "batch_size"):
res.batch_size = getattr(dataloader.batch_sampler, "batch_size")
# 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性;
else:
dataloader_iter = iter(dataloader)
pre_sample = next(dataloader_iter)
res.batch_size = pre_sample.shape[0]
if hasattr(dataloader.batch_sampler, "sampler"):
res.sampler = dataloader.batch_sampler.sampler
if hasattr(dataloader.batch_sampler.sampler, "shuffle"):
res.shuffle = dataloader.batch_sampler.sampler.shuffle
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler):
res.shuffle = True
else:
res.shuffle = False
else:
res.sampler = None
res.shuffle = False
if hasattr(dataloader.batch_sampler, "drop_last"):
res.drop_last = getattr(dataloader.batch_sampler, "drop_last")
# 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性;
else:
res.drop_last = False
return res

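When a custom batch_sampler exposes no `batch_size`, `get_dataloader_args` above falls back to drawing one batch and reading its leading dimension. A self-contained sketch of that probing idea (note the assumption that a batch field is a tensor whose dim 0 is the batch size, which a custom collate function can break):

import paddle
from paddle.io import DataLoader, TensorDataset

ds = TensorDataset([paddle.zeros([10, 3])])
dl = DataLoader(ds, batch_size=4)
pre_sample = next(iter(dl))
# TensorDataset batches arrive as a list with one tensor per field
print(pre_sample[0].shape[0])  # 4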
View File

@@ -2,6 +2,7 @@ import os
from typing import Optional, Dict, Union
from .paddle_driver import PaddleDriver
from .utils import replace_batch_sampler, replace_sampler
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.utils import (
@@ -10,7 +11,7 @@ from fastNLP.core.utils import (
get_paddle_device_id,
paddle_move_data_to_device,
)
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator, re_instantiate_sampler
from fastNLP.core.log import logger
if _NEED_IMPORT_PADDLE:
@@ -93,11 +94,8 @@ class PaddleSingleDriver(PaddleDriver):
self._test_signature_fn = model.forward
def setup(self):
user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES]
device_id = get_paddle_device_id(self.model_device)
if user_visible_devices is not None and user_visible_devices != "":
# not empty, meaning the user has set CUDA_VISIBLE_DEVICES
device_id = user_visible_devices.split(",")[device_id]
device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
paddle.device.set_device("gpu:0")
self.model.to("gpu:0")
@@ -145,26 +143,25 @@ class PaddleSingleDriver(PaddleDriver):
assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now."
if isinstance(dist, ReproducibleBatchSampler):
dataloader.batch_sampler = dist
return dataloader
if isinstance(dist, ReproducibleIterator):
dataloader.batch_sampler.sampler = dist
return dataloader
return replace_batch_sampler(dataloader, dist)
elif isinstance(dist, ReproducibleIterator):
return replace_sampler(dataloader, dist)
if reproducible:
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
return dataloader
args = self.get_dataloader_args(dataloader)
if isinstance(args.sampler, ReproducibleIterator):
sampler = re_instantiate_sampler(args.sampler)
return replace_sampler(dataloader, sampler)
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
return dataloader
batch_sampler = re_instantiate_sampler(dataloader.batch_sampler)
return replace_batch_sampler(dataloader, batch_sampler)
else:
# TODO
batch_sampler = ReproducibleBatchSampler(
batch_sampler=dataloader.batch_sampler,
batch_size=dataloader.batch_sampler.batch_size,
drop_last=dataloader.drop_last
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
dataloader.batch_sampler = batch_sampler
return dataloader
return replace_batch_sampler(dataloader, batch_sampler)
else:
return dataloader

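For reproducibility, `set_dist_repro_dataloader` above wraps a plain batch_sampler in a `ReproducibleBatchSampler`, which must track how far iteration has progressed and expose `state_dict`/`load_state_dict` (as relied upon in `save`/`load` earlier). A rough sketch of that contract, not the actual fastNLP class:

class _SketchReproducibleBatchSampler:
    def __init__(self, batch_sampler, batch_size, drop_last):
        self.batch_sampler = batch_sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_consumed_samples = 0  # advanced while iterating

    def __iter__(self):
        for batch in self.batch_sampler:
            self.num_consumed_samples += len(batch)
            yield batch

    def state_dict(self):
        return {"num_consumed_samples": self.num_consumed_samples}

    def load_state_dict(self, states):
        # a real implementation would also skip the already-consumed indices
        self.num_consumed_samples = states["num_consumed_samples"]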
View File

@@ -9,7 +9,7 @@ from enum import IntEnum
from typing import Dict, Optional, Union
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.utils import get_paddle_device_id, auto_param_call
from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to
from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.log import logger
@@ -272,11 +272,9 @@ def get_device_from_visible(device: Union[str, int]):
else:
# use USER_CUDA_VISIBLE_DEVICES to obtain the device the user expects
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if user_visible_devices is not None and user_visible_devices != "":
# not empty, meaning the user has set CUDA_VISIBLE_DEVICES
idx = user_visible_devices.split(",")[idx]
else:
idx = str(idx)
if user_visible_devices is None:
raise RuntimeError("This situation cannot happen, please report a bug to us.")
idx = user_visible_devices.split(",")[idx]
cuda_visible_devices_list = cuda_visible_devices.split(',')
assert idx in cuda_visible_devices_list, "Can't find "\
@@ -285,31 +283,44 @@ def get_device_from_visible(device: Union[str, int]):
res = cuda_visible_devices_list.index(idx)
return res
def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
# get the instance attributes;
def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"):
"""
Rebuild a DataLoader from `batch_sampler`, replacing the `batch_sampler` without affecting the original `dataloader`.
The case where the user has customized the DataLoader is taken into account.
"""
# get the instance attributes that do not start with an underscore;
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}
# get the default signature of the dataloader's '__init__' function;
# get the default signature of the dataloader's '__init__' function; this exposes the parameter names, default values and kinds
init_params = dict(inspect.signature(dataloader.__init__).parameters)
# The reason this needs special handling: when customizing their own dataloader, a user may for convenience declare only
# some parameters explicitly and pass the rest straight through **kwargs. If they then added some other, new parameters when
# initializing their dataloader instance (first, this step is necessary, since it is the only way we can inject the sampler;
# second, the user may indeed have passed new parameters via **kwargs), and they wrote "super().__init__(**kwargs)", then we
# can only look those parameters up in DataLoader;
# can only look those parameters up in DataLoader; VAR_KEYWORD stands for **kwargs
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
if has_variadic_kwargs:
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters))
del init_params["self"]
# since we may just have overwritten the customized dataloader's parameters with DataLoader's defaults, this has to be redone;
# collect the parameters that appear both as instance attributes and as parameter names, and whose values differ from the defaults
non_default_params = {name for name, p in init_params.items() if
name in instance_attrs and p.default != instance_attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_default_params.add("dataset")
# collect the non-default parameters together with their values
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args.update({"batch_sampler": sampler, "shuffle": False, "drop_last": False, "batch_size": 1})
# the class member backing persistent_workers carries a leading underscore, so add it here explicitly
reconstruct_args.update({
"batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
"persistent_workers": dataloader._persistent_workers,
})
# POSITIONAL_OR_KEYWORD stands for ordinary parameters
# collect the ordinary parameters that appear in __init__ without a default value and are not in reconstruct_args,
# i.e. those that do not appear both in the __init__ signature and as instance members
required_args = {
p.name
for p in init_params.values()
@@ -323,12 +334,9 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
required_args = sorted(required_args)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. "
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or "
"manually add the `DistributedBatchSampler` as: "
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`."
)
# this error targets the case where the dataloader passed in is not a plain DataLoader but a customized one whose __init__ has no **kwargs;
@@ -340,12 +348,33 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
missing_kwargs = sorted(missing_kwargs)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. "
f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or "
"manually add the `DistributedBatchSampler` as: "
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`."
)
return type(dataloader)(**reconstruct_args)
def replace_sampler(dataloader, new_sampler):
"""
Rebuild a BatchSampler around `new_sampler` and substitute it into `dataloader`.
"""
new_batch_sampler = BatchSampler(
dataset=dataloader.batch_sampler.dataset,
sampler=new_sampler,
shuffle=isinstance(dataloader.batch_sampler.sampler, paddle.io.RandomSampler),
batch_size=dataloader.batch_sampler.batch_size,
drop_last=dataloader.batch_sampler.drop_last
)
return replace_batch_sampler(dataloader, new_batch_sampler)
def optimizer_state_to_device(state, device):
new_state = {}
for name, param in state.items():
if isinstance(param, dict):
new_state[name] = optimizer_state_to_device(param, device)
elif isinstance(param, paddle.Tensor):
new_state[name] = paddle_to(param, device).clone()
else:
new_state[name] = param
return new_state
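`optimizer_state_to_device` recurses through the (possibly nested) state dict, moving tensor leaves with fastNLP's `paddle_to` helper and passing every other value through unchanged. A hedged usage sketch on a toy nested state (the key names below are made up for illustration):

import paddle

toy_state = {
    "moment1": {"linear_0.w_0": paddle.zeros([2, 2])},  # nested dict of tensors
    "beta1_pow_acc": paddle.ones([1]),                  # tensor leaf
    "epoch": 3,                                         # non-tensor, passed through
}
cpu_state = optimizer_state_to_device(toy_state, "cpu")  # tensors moved and cloned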