Follow up on the settings for checkpoint-based resumed training
commit 9678c559c9 (parent da849564d6)
@@ -10,6 +10,7 @@ from .utils import (
     _MODE_PARAMETER,
     get_device_from_visible,
     reset_seed,
+    replace_sampler
 )

 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@@ -19,8 +20,13 @@ from fastNLP.core.utils import (
     paddle_move_data_to_device,
     is_in_paddle_dist,
 )
-from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
-from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES
+from fastNLP.core.samplers import (
+    ReproducibleIterator,
+    RandomSampler,
+    UnrepeatedDistributedSampler,
+    re_instantiate_sampler,
+)
+from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
 from fastNLP.core.log import logger

 if _NEED_IMPORT_PADDLE:
@@ -314,23 +320,15 @@ class PaddleFleetDriver(PaddleDriver):

     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):

         # IterableDataset is not supported for now
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IterableDataset` now."
         if isinstance(dist, ReproducibleIterator):
-            dataloader.batch_sampler.sampler = dist
-            return dataloader
-
-        # paddle's BatchSampler and DataLoader have no `shuffle` member, so shuffling can only be inferred from the sampler,
-        # but the subclass DistributedBatchSampler does have a `shuffle` member,
-        # hence the strict type() check
-        if type(dataloader.batch_sampler) == BatchSampler:
-            shuffle = isinstance(dataloader.batch_sampler.sampler, RandomSampler)
-        else:
-            shuffle = dataloader.batch_sampler.shuffle
+            dist = re_instantiate_sampler(dist)
+            return replace_sampler(dataloader, dist)

         # trainer, evaluator
         # the user initialized the distributed environment themselves; do nothing
         if dist is None:
             if reproducible:
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
@@ -339,40 +337,40 @@ class PaddleFleetDriver(PaddleDriver):
             return dataloader
         # trainer
         elif dist == "dist":
+            args = self.get_dataloader_args(dataloader)
             # if the user's trainer.use_dist_sampler is True, whether checkpoint retraining is enabled does not affect the behavior here;
-            if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
-                dataloader.batch_sampler.sampler.set_distributed(
+            if isinstance(args.sampler, ReproducibleIterator):
+                sampler = re_instantiate_sampler(args.sampler)
+                sampler.set_distributed(
                     num_replicas=self.world_size,
                     rank=self.global_rank,
                     pad=True
                 )
-                return dataloader
+                return replace_sampler(dataloader, sampler)
             else:
                 sampler = RandomSampler(
-                    dataset=dataloader.dataset,
-                    shuffle=shuffle,
-                    seed=int(os.environ.get("FASTNLP_SEED", 0))
+                    dataset=args.dataset,
+                    shuffle=args.shuffle,
+                    seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
                 )
                 sampler.set_distributed(
                     num_replicas=self.world_size,
                     rank=self.global_rank,
                     pad=True
                 )
-                dataloader.batch_sampler.sampler = sampler
-                return dataloader
+                return replace_sampler(dataloader, sampler)
         # evaluator
         elif dist == "unrepeatdist":
+            args = self.get_dataloader_args(dataloader)
             sampler = UnrepeatedDistributedSampler(
-                dataset=dataloader.dataset,
-                shuffle=shuffle,
-                seed=int(os.environ.get("FASTNLP_SEED", 0))
+                dataset=args.dataset,
+                shuffle=args.shuffle,
             )
             sampler.set_distributed(
                 num_replicas=self.world_size,
                 rank=self.global_rank
             )
-            dataloader.batch_sampler.sampler = sampler
-            return dataloader
+            return replace_sampler(dataloader, sampler)
         else:
             raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
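The new paths above hand a fresh sampler to replace_sampler instead of mutating dataloader.batch_sampler.sampler in place. The helper re_instantiate_sampler itself is not part of this diff; a minimal sketch of such a helper, assuming every constructor argument is stored verbatim as an instance attribute (as holds for RandomSampler(dataset, shuffle, seed) above), would be:

def re_instantiate_sampler(sampler):
    # rebuild a sampler of the same concrete type from its instance attributes,
    # yielding an independent copy whose iteration state starts fresh
    return type(sampler)(**vars(sampler))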
@@ -1,21 +1,31 @@
 import os
 import random
-from typing import Union, Optional, Callable, Dict
+from typing import Union, Optional, Dict
+from pathlib import Path
 from functools import partial
+from dataclasses import dataclass

 import numpy as np

-from .utils import _build_fp16_env
+from .utils import _build_fp16_env, optimizer_state_to_device
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
 from fastNLP.core.drivers.driver import Driver
 from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device
 from fastNLP.envs import rank_zero_call
-from fastNLP.envs import FASTNLP_SEED_WORKERS
+from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
 from fastNLP.core.samplers import ReproducibleBatchSampler

 if _NEED_IMPORT_PADDLE:
     import paddle
-    from paddle.io import DataLoader, IterableDataset
+    from paddle.io import (
+        DataLoader,
+        IterableDataset,
+        Dataset,
+        Sampler,
+        BatchSampler,
+        RandomSampler,
+    )
     from paddle.optimizer import Optimizer

 _reduces = {
@@ -69,6 +79,8 @@ class PaddleDriver(Driver):
             # TODO for now we forbid the dataloader's dataset from being an IterableDataset;
             if isinstance(dataloader.dataset, IterableDataset):
                 raise TypeError("`IterableDataset` is not allowed.")
+            if dataloader.batch_sampler is None and dataloader.batch_size is None:
+                raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.")
         else:
             if not isinstance(dataloader, Dict):
                 raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.")
@@ -79,6 +91,9 @@ class PaddleDriver(Driver):
                                      f"type, not {type(each_dataloader)}.")
                 if isinstance(each_dataloader.dataset, IterableDataset):
                     raise TypeError("`IterableDataset` is not allowed.")
+                if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None:
+                    raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of "
+                                     f"`batch_sampler` and `batch_size` should be set.")

     @staticmethod
     def _check_optimizer_legality(optimizers):
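The new legality check rejects dataloaders that can neither batch by themselves nor be told a batch size. A hypothetical trigger, with a toy dataset assumed for illustration:

import paddle
from paddle.io import DataLoader, Dataset

class ToyDataset(Dataset):
    def __getitem__(self, idx):
        return paddle.zeros([4])

    def __len__(self):
        return 16

# batch_size=None together with batch_sampler=None leaves the driver no way to
# batch the data, so the check above raises a ValueError for this loader
loader = DataLoader(ToyDataset(), batch_size=None, batch_sampler=None)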
@@ -153,45 +168,53 @@ class PaddleDriver(Driver):
         getattr(self.model, mode)()

     @rank_zero_call
-    def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable]=None, **kwargs):
+    def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
         r"""
         Saves the model; note that `save` is the function used for checkpoint retraining;
-        if `model_save_fn` is callable, we run that function directly;

         :param filepath: location of the file to save (must include the file name);
-        :param only_state_dict: whether to save only the model's `state_dict`; note this parameter only takes effect when `model_save_fn` is None;
-        :param model_save_fn: a user-provided function that replaces this function's own saving logic; if it is not None, we call model_save_fn(path);
+        :param only_state_dict: whether to save only the model's `state_dict`;
         :param kwargs:
         :return:
         """
-        if model_save_fn is not None:
-            model_save_fn(filepath)
-        else:
-            model = self.unwrap_model()
-            if only_state_dict:
-                paddle.save(model.state_dict(), filepath)
-            else:
-                input_spec = kwargs.get("input_spec", None)
-                if input_spec is None:
-                    raise Exception("To save the whole Paddle Layer, parameter 'input_spec' is needed.")
+        model = self.unwrap_model()
+
+        if only_state_dict:
+            states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
+            paddle.save(states, filepath)
+        else:
+            # paddle needs extra arguments when saving the whole model
+            input_spec = kwargs.get("input_spec", None)
+            if input_spec is None:
+                raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
+            if self.model_device is not None:
+                if not self.is_distributed():
+                    self.move_model_to_device(model, "cpu")
+                paddle.jit.save(model, filepath, input_spec)
+                if not self.is_distributed():
+                    self.move_model_to_device(model, self.model_device)
+            else:
                 paddle.jit.save(model, filepath, input_spec)
|
||||
@rank_zero_call
|
||||
def load_model(filepath: str, load_dict: bool = True):
|
||||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
|
||||
r"""
|
||||
加载模型的函数;注意函数 `load` 是用来进行断点重训的函数;
|
||||
|
||||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名);
|
||||
:param load_dict: 是否加载state_dict,默认为True。当用户在save_model时将only_state_dict设置为False时,
|
||||
即保存了整个模型时,这个参数必须也为False
|
||||
:return: 返回加载指定文件后的结果;
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if load_dict:
|
||||
return paddle.load(filepath)
|
||||
model = self.unwrap_model()
|
||||
if only_state_dict:
|
||||
model.load_dict(paddle.load(filepath))
|
||||
else:
|
||||
return paddle.jit.load(filepath)
|
||||
model.load_dict(paddle.jit.load(filepath).state_dict())
|
||||
|
||||
@rank_zero_call
|
||||
def save(self, folder, states: Dict):
|
||||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
|
||||
r"""
|
||||
断点重训的保存函数,该函数会负责保存模型和 optimizers 的 state_dict;
|
||||
需要注意 driver 应当是无状态的,即不管什么时候调用 driver 的接口函数,其返回的结果应该都是一样的;因此,断点重训不需要保存 driver
|
||||
@ -203,48 +226,110 @@ class PaddleDriver(Driver):
|
||||
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存
|
||||
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load()返回的值与这里的
|
||||
传入的值保持一致。
|
||||
:param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。
|
||||
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。
|
||||
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。
|
||||
:return:
|
||||
"""
|
||||
# 1. 保存模型的状态;
|
||||
model = self.unwrap_model()
|
||||
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
|
||||
# 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失;
|
||||
states["model_state_dict"] = model_state_dict
|
||||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变
|
||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境;
|
||||
|
||||
# 2. 保存 optimizers 的状态;
|
||||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch;
|
||||
# paddle 的 DataLoader 在初始化之后 batch_sampler 可能为 None,也可能为用户设置的 batch_sampler
|
||||
dataloader_args = self.get_dataloader_args(dataloader)
|
||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
|
||||
sampler = dataloader_args.batch_sampler
|
||||
elif dataloader_args.sampler:
|
||||
sampler = dataloader_args.sampler
|
||||
else:
|
||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
|
||||
|
||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
|
||||
states['sampler_states'] = sampler.state_dict()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
|
||||
|
||||
# 2. 保存模型的状态;
|
||||
if should_save_model:
|
||||
model = self.unwrap_model()
|
||||
if only_state_dict:
|
||||
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
|
||||
paddle.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME))
|
||||
logger.debug("Save model state dict")
|
||||
else:
|
||||
input_spec = kwargs.get("input_spec", None)
|
||||
if input_spec is None:
|
||||
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
|
||||
paddle.jit.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME), input_spec)
|
||||
logger.debug("Save model")
|
||||
|
||||
# 3. 保存 optimizers 的状态;
|
||||
optimizers_state_dict = {}
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: Optimizer = self.optimizers[i]
|
||||
optimizer_state = optimizer.state_dict()
|
||||
optimizer_state = {name: param.cpu().detach().clone() for name, param in optimizer_state.items()}
|
||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu")
|
||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;
|
||||
|
||||
logger.debug("Save optimizer state dict")
|
||||
states["optimizers_state_dict"] = optimizers_state_dict
|
||||
paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
|
||||
|
||||
paddle.save(states, folder)
|
||||
|
||||
def load(self, filepath) -> Dict:
|
||||
r"""
|
||||
断点重训的加载函数,注意该函数会负责读取数据,并且恢复模型和 optimizers 的 state_dict 等;
|
||||
driver 实例需要在该函数中先加载模型和 optimizers 的 state_dict,然后将一个 state 字典返回给 trainer 。
|
||||
因此 save 函数和 load 函数的接受和返回值应该是对应的;
|
||||
|
||||
该函数需要在所有 rank 上执行。
|
||||
|
||||
:param filepath: 保存断点重训的状态的文件名;
|
||||
:return: 需要返回 save 函数输入的 states 内容;
|
||||
"""
|
||||
states = paddle.load(filepath)
|
||||
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
|
||||
|
||||
states = paddle.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
|
||||
|
||||
# 1. 加载 optimizers 的状态;
|
||||
optimizers_state_dict = states["optimizers_state_dict"]
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: paddle.optimizer.Optimizer = self.optimizers[i]
|
||||
optimizer: Optimizer = self.optimizers[i]
|
||||
optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"])
|
||||
logger.debug("Load optimizer state dict.")
|
||||
|
||||
# 2. 加载模型状态;
|
||||
model = self.unwrap_model()
|
||||
model.load_dict(states["model_state_dict"])
|
||||
if should_load_model:
|
||||
model = self.unwrap_model()
|
||||
if only_state_dict:
|
||||
res = paddle.load(folder.joinpath(FASTNLP_MODEL_FILENAME))
|
||||
model.load_dict(res)
|
||||
logger.debug("Load model state dict.")
|
||||
else:
|
||||
model.load_dict(paddle.jit.load(folder.joinpath(FASTNLP_MODEL_FILENAME)).state_dict())
|
||||
logger.debug("Load model.")
|
||||
|
||||
# 3. 恢复 sampler 的状态;
|
||||
dataloader_args = self.get_dataloader_args(dataloader)
|
||||
sampler = dataloader_args.sampler
|
||||
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
|
||||
# 说明这里需要使用 ReproduceSampler 来弄一下了
|
||||
if self.is_distributed():
|
||||
raise RuntimeError(
|
||||
"It is not allowed to use single device checkpoint retraining before but ddp now.")
|
||||
sampler = ReproducibleBatchSampler(
|
||||
batch_sampler=sampler,
|
||||
batch_size=dataloader_args.batch_sampler.batch_size,
|
||||
drop_last=dataloader_args.drop_last
|
||||
)
|
||||
sampler.load_state_dict(states['sampler_states'])
|
||||
|
||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
|
||||
|
||||
# 4. 修改 trainer_state.batch_idx_in_epoch
|
||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler;
|
||||
if not isinstance(sampler, ReproducibleBatchSampler):
|
||||
if dataloader_args.drop_last:
|
||||
batch_idx_in_epoch = len(
|
||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
|
||||
else:
|
||||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
|
||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
|
||||
# sampler 是 batch_sampler;
|
||||
else:
|
||||
batch_idx_in_epoch = sampler.batch_idx_in_epoch
|
||||
|
||||
states["batch_idx_in_epoch"] = batch_idx_in_epoch
|
||||
|
||||
self.barrier()
|
||||
return states
|
||||
|
||||
def get_evaluate_context(self):
|
||||
@ -313,3 +398,53 @@ class PaddleDriver(Driver):
|
||||
"""
|
||||
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
|
||||
dataloader.batch_sampler.set_epoch(cur_epoch_idx)
|
||||
|
||||
@staticmethod
|
||||
def get_dataloader_args(dataloader: "DataLoader"):
|
||||
"""
|
||||
获取 dataloader 的 shuffle 和 drop_last 属性;
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class Res:
|
||||
dataset: Optional[Dataset] = None
|
||||
batch_sampler: Optional[BatchSampler] = None
|
||||
sampler: Optional[Sampler] = None
|
||||
batch_size: Optional[int] = None
|
||||
shuffle: Optional[bool] = None
|
||||
drop_last: Optional[bool] = None
|
||||
|
||||
res = Res()
|
||||
|
||||
# paddle 的 DataLoader 一定会有 dataset 属性;
|
||||
res.dataset = dataloader.dataset
|
||||
|
||||
if dataloader.batch_sampler is not None:
|
||||
res.batch_sampler = dataloader.batch_sampler
|
||||
if hasattr(dataloader.batch_sampler, "batch_size"):
|
||||
res.batch_size = getattr(dataloader.batch_sampler, "batch_size")
|
||||
# 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性;
|
||||
else:
|
||||
dataloader_iter = iter(dataloader)
|
||||
pre_sample = next(dataloader_iter)
|
||||
res.batch_size = pre_sample.shape[0]
|
||||
|
||||
if hasattr(dataloader.batch_sampler, "sampler"):
|
||||
res.sampler = dataloader.batch_sampler.sampler
|
||||
if hasattr(dataloader.batch_sampler.sampler, "shuffle"):
|
||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle
|
||||
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler):
|
||||
res.shuffle = True
|
||||
else:
|
||||
res.shuffle = False
|
||||
else:
|
||||
res.sampler = None
|
||||
res.shuffle = False
|
||||
|
||||
if hasattr(dataloader.batch_sampler, "drop_last"):
|
||||
res.drop_last = getattr(dataloader.batch_sampler, "drop_last")
|
||||
# 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性;
|
||||
else:
|
||||
res.drop_last = False
|
||||
|
||||
return res
|
||||
|
@@ -2,6 +2,7 @@ import os
 from typing import Optional, Dict, Union

 from .paddle_driver import PaddleDriver
+from .utils import replace_batch_sampler, replace_sampler
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
 from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
 from fastNLP.core.utils import (
@@ -10,7 +11,7 @@ from fastNLP.core.utils import (
     get_paddle_device_id,
     paddle_move_data_to_device,
 )
-from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator, re_instantiate_sampler
 from fastNLP.core.log import logger

 if _NEED_IMPORT_PADDLE:
@@ -93,11 +94,8 @@ class PaddleSingleDriver(PaddleDriver):
         self._test_signature_fn = model.forward

     def setup(self):
-        user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES]
         device_id = get_paddle_device_id(self.model_device)
-        if user_visible_devices is not None and user_visible_devices != "":
-            # not empty, meaning the user set CUDA_VISIBLE_DEVICES
-            device_id = user_visible_devices.split(",")[device_id]
+        device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
         os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
         paddle.device.set_device("gpu:0")
         self.model.to("gpu:0")
@@ -145,26 +143,25 @@ class PaddleSingleDriver(PaddleDriver):
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IterableDataset` now."
         if isinstance(dist, ReproducibleBatchSampler):
-            dataloader.batch_sampler = dist
-            return dataloader
-        if isinstance(dist, ReproducibleIterator):
-            dataloader.batch_sampler.sampler = dist
-            return dataloader
+            return replace_batch_sampler(dataloader, dist)
+        elif isinstance(dist, ReproducibleIterator):
+            return replace_sampler(dataloader, dist)

         if reproducible:
-            if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
-                return dataloader
+            args = self.get_dataloader_args(dataloader)
+            if isinstance(args.sampler, ReproducibleIterator):
+                sampler = re_instantiate_sampler(args.sampler)
+                return replace_sampler(dataloader, sampler)
             elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
-                return dataloader
+                batch_sampler = re_instantiate_sampler(dataloader.batch_sampler)
+                return replace_batch_sampler(dataloader, batch_sampler)
             else:
                 # TODO
                 batch_sampler = ReproducibleBatchSampler(
-                    batch_sampler=dataloader.batch_sampler,
-                    batch_size=dataloader.batch_sampler.batch_size,
-                    drop_last=dataloader.drop_last
+                    batch_sampler=args.batch_sampler,
+                    batch_size=args.batch_size,
+                    drop_last=args.drop_last
                 )
-                dataloader.batch_sampler = batch_sampler
-                return dataloader
+                return replace_batch_sampler(dataloader, batch_sampler)
         else:
             return dataloader
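ReproducibleBatchSampler is what makes the reproducible=True path checkpointable: it wraps an arbitrary batch_sampler and tracks what has been consumed, so state_dict()/load_state_dict() can restore a mid-epoch position. A minimal sketch of that idea (not fastNLP's actual implementation):

class CountingBatchSampler:
    """Wrap a batch sampler and remember how many batches were handed out."""
    def __init__(self, batch_sampler, batch_size, drop_last):
        self.batch_sampler = batch_sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_consumed_batches = 0

    def __iter__(self):
        for batch in self.batch_sampler:
            self.num_consumed_batches += 1
            yield batch

    def state_dict(self):
        return {"num_consumed_batches": self.num_consumed_batches}

    def load_state_dict(self, state):
        # on resume, a real implementation would also fast-forward the wrapped
        # sampler past this many batches before yielding new ones
        self.num_consumed_batches = state["num_consumed_batches"]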
@@ -9,7 +9,7 @@ from enum import IntEnum
 from typing import Dict, Optional, Union

 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
-from fastNLP.core.utils import get_paddle_device_id, auto_param_call
+from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to
 from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES
 from fastNLP.core.log import logger

@@ -272,11 +272,9 @@ def get_device_from_visible(device: Union[str, int]):
     else:
         # use USER_CUDA_VISIBLE_DEVICES to get the device the user expects
         user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
-        if user_visible_devices is not None and user_visible_devices != "":
-            # not empty, meaning the user set CUDA_VISIBLE_DEVICES
-            idx = user_visible_devices.split(",")[idx]
-        else:
-            idx = str(idx)
+        if user_visible_devices is None:
+            raise RuntimeError("This situation cannot happen, please report a bug to us.")
+        idx = user_visible_devices.split(",")[idx]

     cuda_visible_devices_list = cuda_visible_devices.split(',')
     assert idx in cuda_visible_devices_list, "Can't find "\
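Concretely, with assumed values: if the user originally exposed USER_CUDA_VISIBLE_DEVICES="3,5,7" and the current process sees CUDA_VISIBLE_DEVICES="5,7", then asking for logical device 1 resolves as:

user_visible = "3,5,7".split(",")    # devices the user originally exposed
cuda_visible = "5,7".split(",")      # devices visible to this process

idx = user_visible[1]                # logical device 1 -> physical GPU "5"
res = cuda_visible.index(idx)        # -> 0, i.e. "gpu:0" inside this process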
@@ -285,31 +283,44 @@ def get_device_from_visible(device: Union[str, int]):
     res = cuda_visible_devices_list.index(idx)
     return res

-def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
-    # grab the instance attributes;
+def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"):
+    """
+    Rebuild a DataLoader around `batch_sampler`, replacing the `batch_sampler` without touching the original `dataloader`.
+    Takes into account the case where the user customized the DataLoader.
+    """
+    # grab the instance attributes that do not start with an underscore;
     instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}

-    # grab the default signature of the dataloader's '__init__';
+    # grab the default signature of the dataloader's '__init__'; it yields the parameter names, default values and types
     init_params = dict(inspect.signature(dataloader.__init__).parameters)

     # the reason this needs special handling: when customizing their own dataloader the user may, for convenience, set
     # only a few parameters explicitly and forward the rest via **kwargs; they may have added new parameters when
     # instantiating their dataloader (this step is necessary in the first place, since it is the only way to inject a
-    # sampler), and if they wrote something like "super().__init__(**kwargs)" we can only look those up in DataLoader;
+    # sampler), and if they wrote something like "super().__init__(**kwargs)" we can only look those up in DataLoader; VAR_KEYWORD stands for **kwargs
     has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
     if has_variadic_kwargs:
         init_params.update(dict(inspect.signature(DataLoader.__init__).parameters))
         del init_params["self"]

-    # because we may just have overwritten the user's customized dataloader parameters with DataLoader's defaults, redo this;
+    # collect the parameters that appear both as instance attributes and as __init__ parameters and differ from the defaults
     non_default_params = {name for name, p in init_params.items() if
                           name in instance_attrs and p.default != instance_attrs[name]}
     # add `dataset` as it might have been replaced with `*args`
     non_default_params.add("dataset")

+    # collect the non-default parameters and their values
     reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
-    reconstruct_args.update({"batch_sampler": sampler, "shuffle": False, "drop_last": False, "batch_size": 1})
+    # the class member backing persistent_workers carries an underscore, so add it explicitly
+    reconstruct_args.update({
+        "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1,
+        "persistent_workers": dataloader._persistent_workers,
+    })

+    # POSITIONAL_OR_KEYWORD stands for ordinary parameters
+    # collect the ordinary __init__ parameters that have no default and are not in reconstruct_args,
+    # i.e. those that do not appear in both __init__ and the instance attributes
     required_args = {
         p.name
         for p in init_params.values()
@@ -323,12 +334,9 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
         required_args = sorted(required_args)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. "
+            f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
             f"The missing attributes are {required_args}. "
-            f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or "
-            "manually add the `DistributedBatchSampler` as: "
-            f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`."
         )

     # this error targets the case where the dataloader passed in is not a plain DataLoader but a customized one whose __init__ lacks **kwargs;
@@ -340,12 +348,33 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"):
         missing_kwargs = sorted(missing_kwargs)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. "
+            f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. "
             "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
             f"The missing arguments are {missing_kwargs}. "
             f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or "
-            "manually add the `DistributedBatchSampler` as: "
-            f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`."
         )

     return type(dataloader)(**reconstruct_args)
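The attribute/signature dance above can be seen on a toy class; a self-contained sketch of the same reconstruction trick (plain Python, no paddle needed):

import inspect

class Box:
    def __init__(self, size, color="red", **kwargs):
        self.size = size
        self.color = color

box = Box(3, color="blue")

instance_attrs = {k: v for k, v in vars(box).items() if not k.startswith("_")}
init_params = dict(inspect.signature(box.__init__).parameters)

# keep only attributes that shadow an __init__ parameter with a non-default value
non_default = {n for n, p in init_params.items()
               if n in instance_attrs and p.default != instance_attrs[n]}
rebuilt = Box(**{k: v for k, v in instance_attrs.items() if k in non_default})
print(vars(rebuilt))    # {'size': 3, 'color': 'blue'}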
+def replace_sampler(dataloader, new_sampler):
+    """
+    Rebuild a BatchSampler around `new_sampler` and substitute it into `dataloader`
+    """
+    new_batch_sampler = BatchSampler(
+        dataset=dataloader.batch_sampler.dataset,
+        sampler=new_sampler,
+        shuffle=isinstance(dataloader.batch_sampler.sampler, paddle.io.RandomSampler),
+        batch_size=dataloader.batch_sampler.batch_size,
+        drop_last=dataloader.batch_sampler.drop_last
+    )
+    return replace_batch_sampler(dataloader, new_batch_sampler)

+def optimizer_state_to_device(state, device):
+    new_state = {}
+    for name, param in state.items():
+        if isinstance(param, dict):
+            new_state[name] = optimizer_state_to_device(param, device)
+        elif isinstance(param, paddle.Tensor):
+            new_state[name] = paddle_to(param, device).clone()
+        else:
+            new_state[name] = param
+    return new_state
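optimizer_state_to_device walks a possibly nested optimizer state dict and moves only the tensor leaves. A hedged usage sketch, with the model assumed and paddle_to assumed to behave like Tensor.to(device):

import paddle

opt = paddle.optimizer.Adam(parameters=model.parameters())   # model assumed
state = opt.state_dict()                     # nested dict of tensors and scalars
cpu_state = optimizer_state_to_device(state, "cpu")
# tensors are now on CPU; ints/floats (e.g. step counters) pass through unchanged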