mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-29 18:59:01 +08:00
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0
This commit is contained in:
commit
7ca1abfba5
@ -1,8 +1,11 @@
|
||||
pipeline {
|
||||
agent none
|
||||
agent any
|
||||
options {
|
||||
timeout(time:30, unit: 'MINUTES')
|
||||
}
|
||||
environment {
|
||||
PJ_NAME = 'fastNLP'
|
||||
POST_URL = 'https://open.feishu.cn/open-apis/bot/v2/hook/14719364-818d-4f88-9057-7c9f0eaaf6ae'
|
||||
POST_URL = 'https://open.feishu.cn/open-apis/bot/v2/hook/2f7122e3-3459-43d2-a9e4-ddd77bfc4282'
|
||||
}
|
||||
stages {
|
||||
stage('Parallel Stages') {
|
||||
@ -15,7 +18,12 @@ pipeline {
|
||||
}
|
||||
}
|
||||
steps {
|
||||
sh 'pytest ./tests --durations=0 -m "not (torch or paddle or paddledist or jittor or torchpaddle or torchjittor)"'
|
||||
sh 'pytest ./tests --durations=0 --html=other.html --self-contained-html -m "not (torch or paddle or paddledist or jittor or torchpaddle or torchjittor)"'
|
||||
}
|
||||
post {
|
||||
always {
|
||||
sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv other.html ${html_path}'
|
||||
}
|
||||
}
|
||||
}
|
||||
stage('Test Torch-1.11') {
|
||||
@ -26,7 +34,12 @@ pipeline {
|
||||
}
|
||||
}
|
||||
steps {
|
||||
sh 'pytest ./tests --durations=0 -m torch'
|
||||
sh 'pytest ./tests/ --durations=0 --html=torch-1.11.html --self-contained-html -m torch'
|
||||
}
|
||||
post {
|
||||
always {
|
||||
sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv torch-1.11.html ${html_path}'
|
||||
}
|
||||
}
|
||||
}
|
||||
stage('Test Torch-1.6') {
|
||||
@ -37,7 +50,12 @@ pipeline {
|
||||
}
|
||||
}
|
||||
steps {
|
||||
sh 'pytest ./tests/ --durations=0 -m torch'
|
||||
sh 'pytest ./tests/ --durations=0 --html=torch-1.6.html --self-contained-html -m torch'
|
||||
}
|
||||
post {
|
||||
always {
|
||||
sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv torch-1.6.html ${html_path}'
|
||||
}
|
||||
}
|
||||
}
|
||||
stage('Test Paddle') {
|
||||
@ -48,11 +66,16 @@ pipeline {
|
||||
}
|
||||
}
|
||||
steps {
|
||||
sh 'pytest ./tests --durations=0 -m paddle --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests --durations=0 -m paddle --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/drivers/paddle_driver/test_dist_utils.py --durations=0 --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/drivers/paddle_driver/test_fleet.py --durations=0 --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/controllers/test_trainer_paddle.py --durations=0 --co'
|
||||
sh 'pytest ./tests --durations=0 --html=paddle.html --self-contained-html -m paddle --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests --durations=0 --html=paddle_with_backend.html --self-contained-html -m paddle --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/drivers/paddle_driver/test_dist_utils.py --durations=0 --html=paddle_dist_utils.html --self-contained-html --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/drivers/paddle_driver/test_fleet.py --durations=0 --html=paddle_fleet.html --self-contained-html --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/controllers/test_trainer_paddle.py --durations=0 --html=paddle_trainer.html --self-contained-html --co'
|
||||
}
|
||||
post {
|
||||
always {
|
||||
sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv paddle*.html ${html_path}'
|
||||
}
|
||||
}
|
||||
}
|
||||
// stage('Test Jittor') {
|
||||
@ -65,7 +88,7 @@ pipeline {
|
||||
// steps {
|
||||
// // sh 'pip install fitlog'
|
||||
// // sh 'pytest ./tests --html=test_results.html --self-contained-html'
|
||||
// sh 'pytest ./tests --durations=0 -m jittor --co'
|
||||
// sh 'pytest ./tests --durations=0 --html=jittor.html --self-contained-html -m jittor --co'
|
||||
// }
|
||||
// }
|
||||
}
|
||||
@ -77,7 +100,7 @@ pipeline {
|
||||
}
|
||||
success {
|
||||
sh 'post 0'
|
||||
sh 'post github'
|
||||
// sh 'post github'
|
||||
}
|
||||
}
|
||||
}
|
@ -9,7 +9,7 @@ SPHINXPROJ = fastNLP
|
||||
SPHINXEXCLUDE = ../fastNLP/transformers/*
|
||||
SOURCEDIR = source
|
||||
BUILDDIR = build
|
||||
PORT = 9000
|
||||
PORT = 8000
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@ -30,6 +30,9 @@ web:
|
||||
dev:
|
||||
make delete && make apidoc && make html && make server
|
||||
|
||||
versions:
|
||||
sphinx-multiversion "$(SOURCEDIR)" "$(BUILDDIR)" && cd build && python -m http.server $(PORT)
|
||||
|
||||
prod:
|
||||
make apidoc && make html
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
sphinx
|
||||
sphinx_rtd_theme
|
||||
sphinx_autodoc_typehints
|
||||
sphinx_autodoc_typehints
|
||||
sphinx-multiversion
|
27
docs/source/_templates/versions.html
Normal file
27
docs/source/_templates/versions.html
Normal file
@ -0,0 +1,27 @@
|
||||
{%- if current_version %}
|
||||
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
|
||||
<span class="rst-current-version" data-toggle="rst-current-version">
|
||||
<span class="fa fa-book"> Other Versions</span>
|
||||
v: {{ current_version.name }}
|
||||
<span class="fa fa-caret-down"></span>
|
||||
</span>
|
||||
<div class="rst-other-versions">
|
||||
{%- if versions.tags %}
|
||||
<dl>
|
||||
<dt>Tags</dt>
|
||||
{%- for item in versions.tags %}
|
||||
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
|
||||
{%- endfor %}
|
||||
</dl>
|
||||
{%- endif %}
|
||||
{%- if versions.branches %}
|
||||
<dl>
|
||||
<dt>Branches</dt>
|
||||
{%- for item in versions.branches %}
|
||||
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
|
||||
{%- endfor %}
|
||||
</dl>
|
||||
{%- endif %}
|
||||
</div>
|
||||
</div>
|
||||
{%- endif %}
|
@ -43,7 +43,8 @@ extensions = [
|
||||
'sphinx.ext.autosummary',
|
||||
'sphinx.ext.mathjax',
|
||||
'sphinx.ext.todo',
|
||||
'sphinx_autodoc_typehints'
|
||||
'sphinx_autodoc_typehints',
|
||||
'sphinx_multiversion',
|
||||
]
|
||||
|
||||
autodoc_default_options = {
|
||||
@ -116,7 +117,11 @@ html_static_path = ['_static']
|
||||
# 'searchbox.html']``.
|
||||
#
|
||||
# html_sidebars = {}
|
||||
|
||||
html_sidebars = {
|
||||
'**': [
|
||||
'versions.html',
|
||||
],
|
||||
}
|
||||
|
||||
# -- Options for HTMLHelp output ---------------------------------------------
|
||||
|
||||
@ -168,6 +173,8 @@ texinfo_documents = [
|
||||
'Miscellaneous'),
|
||||
]
|
||||
|
||||
# -- Options for Multiversions ----------------------------------------------
|
||||
smv_latest_version = 'dev0.8.0'
|
||||
|
||||
# -- Extension configuration -------------------------------------------------
|
||||
def maybe_skip_member(app, what, name, obj, skip, options):
|
||||
|
@ -54,18 +54,9 @@ class LoadBestModelCallback(HasMonitorCallback):
|
||||
if model_save_fn is not None:
|
||||
assert save_folder is not None, "When passing `model_save_fn`, `save_folder` must be provided."
|
||||
|
||||
if save_folder is not None:
|
||||
if save_folder:
|
||||
if os.path.exists(save_folder):
|
||||
assert os.path.isdir(save_folder), f"`save_folder` must be a directory."
|
||||
else:
|
||||
os.makedirs(save_folder, exist_ok=True)
|
||||
save_folder = os.path.join(save_folder, os.environ.get(FASTNLP_LAUNCH_TIME))
|
||||
self.real_save_folder = os.path.join(save_folder, 'best_so_far')
|
||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
|
||||
os.makedirs(self.real_save_folder, exist_ok=True)
|
||||
else: # 创建出一个 stringio
|
||||
self.real_save_folder = None
|
||||
self.buffer = BytesIO()
|
||||
assert os.path.isdir(save_folder), f"`save_folder={save_folder}` must be a directory."
|
||||
|
||||
self.save_folder = save_folder
|
||||
self.only_state_dict = only_state_dict
|
||||
@ -73,21 +64,37 @@ class LoadBestModelCallback(HasMonitorCallback):
|
||||
self.model_load_fn = model_load_fn
|
||||
self.delete_after_after = delete_after_train
|
||||
|
||||
def on_after_trainer_initialized(self, trainer, driver):
|
||||
if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1:
|
||||
# 如果需要保存,但是又是不是 fastNLP 拉起的, 需要同步一下 folder
|
||||
try:
|
||||
self.real_save_folder = driver.broadcast_object(self.real_save_folder, src=0, group=None)
|
||||
logger.debug(f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.")
|
||||
except NotImplementedError:
|
||||
raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to "
|
||||
f"save best model when launch using module.")
|
||||
def prepare_save_folder(self, trainer):
|
||||
if not hasattr(self, 'real_save_folder'):
|
||||
if self.save_folder is not None:
|
||||
if not os.path.exists(self.save_folder):
|
||||
os.makedirs(self.save_folder, exist_ok=True)
|
||||
self.save_folder = os.path.join(self.save_folder, os.environ.get(FASTNLP_LAUNCH_TIME))
|
||||
self.real_save_folder = os.path.join(self.save_folder, 'best_so_far')
|
||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
|
||||
os.makedirs(self.real_save_folder, exist_ok=True)
|
||||
if self.save_folder is not None and trainer.driver.is_distributed() and int(
|
||||
os.environ.get(FASTNLP_BACKEND_LAUNCH, 0)) == 1:
|
||||
trainer.driver.barrier()
|
||||
try:
|
||||
self.real_save_folder = trainer.driver.broadcast_object(self.real_save_folder, src=0, group=None)
|
||||
logger.debug(
|
||||
f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.")
|
||||
except NotImplementedError:
|
||||
raise RuntimeError(
|
||||
f"Currently {trainer.driver.__class__.__name__} does not support using `save_folder` to "
|
||||
f"save best model when launch using module.")
|
||||
else: # 创建出一个 stringio
|
||||
self.real_save_folder = None
|
||||
self.buffer = BytesIO()
|
||||
|
||||
def on_after_trainer_initialized(self, trainer, driver):
|
||||
super().on_after_trainer_initialized(trainer, driver)
|
||||
self.encounter_exception = False
|
||||
|
||||
def on_evaluate_end(self, trainer, results):
|
||||
if self.is_better_results(results, keep_if_better=True):
|
||||
self.prepare_save_folder(trainer)
|
||||
if self.real_save_folder:
|
||||
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
|
||||
model_save_fn=self.model_save_fn)
|
||||
@ -103,8 +110,7 @@ class LoadBestModelCallback(HasMonitorCallback):
|
||||
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
|
||||
model_load_fn=self.model_load_fn)
|
||||
else:
|
||||
logger.info(
|
||||
f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...")
|
||||
logger.info(f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...")
|
||||
self.buffer.seek(0)
|
||||
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
|
||||
if self.delete_after_after:
|
||||
@ -119,7 +125,7 @@ class LoadBestModelCallback(HasMonitorCallback):
|
||||
self.encounter_exception = True
|
||||
|
||||
def _delete_folder(self):
|
||||
if self.real_save_folder:
|
||||
if getattr(self, 'real_save_folder', None):
|
||||
logger.info(f"Deleting {self.real_save_folder}...")
|
||||
shutil.rmtree(self.real_save_folder, ignore_errors=True)
|
||||
try:
|
||||
|
@ -3,7 +3,11 @@ __all__ = [
|
||||
]
|
||||
from typing import Union, List
|
||||
from ..callback import Callback
|
||||
|
||||
from ...drivers.torch_driver.fairscale import FairScaleDriver
|
||||
from ...drivers.torch_driver import TorchDriver
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_FAIRSCALE
|
||||
if _NEED_IMPORT_FAIRSCALE:
|
||||
from fairscale.nn import FullyShardedDataParallel
|
||||
|
||||
class TorchGradClipCallback(Callback):
|
||||
r"""
|
||||
@ -35,15 +39,20 @@ class TorchGradClipCallback(Callback):
|
||||
else:
|
||||
self.parameters = None
|
||||
self.clip_value = clip_value
|
||||
self.clip_type = clip_type
|
||||
|
||||
def on_after_trainer_initialized(self, trainer, driver):
|
||||
assert 'torch' in driver.__class__.__name__.lower(), f"Callback:{self.__class__.__name__} only supports torch " \
|
||||
assert isinstance(driver, TorchDriver), f"Callback:{self.__class__.__name__} only supports torch " \
|
||||
f"related drivers for now."
|
||||
parameters = []
|
||||
for optimizer in trainer.driver.optimizers:
|
||||
for param_group in optimizer.param_groups:
|
||||
parameters.extend(param_group['params'])
|
||||
self.parameters = parameters
|
||||
if isinstance(trainer.driver, FairScaleDriver):
|
||||
if isinstance(trainer.driver.model, FullyShardedDataParallel) and self.clip_type == 'norm':
|
||||
self.clip_fun = trainer.driver.model.clip_grad_norm_
|
||||
|
||||
assert len(self.parameters), "There is no parameters need to be clipped."
|
||||
|
||||
def on_before_optimizers_step(self, trainer, optimizers):
|
||||
|
@ -58,7 +58,7 @@ class TrainBatchLoop(Loop):
|
||||
trainer.on_train_batch_end()
|
||||
except BaseException as e:
|
||||
if indices is not None and not isinstance(e, (EarlyStopException, KeyboardInterrupt)):
|
||||
logger.error(f"Exception happens when running on samples: {indices}")
|
||||
logger.error(f"Exception happens when training on samples: {indices}")
|
||||
raise e
|
||||
trainer.step_evaluate()
|
||||
trainer.batch_idx_in_epoch = 0
|
||||
|
@ -267,7 +267,8 @@ class Trainer(TrainerEventTrigger):
|
||||
* ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入
|
||||
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等;
|
||||
* set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
|
||||
* torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
|
||||
* non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
|
||||
* gradscaler_kwargs -- 用于 fp16=True 时,提供给 ``torch.amp.cuda.GradScaler`` 的参数。
|
||||
* *paddle_kwargs* -- 用于在指定 ``driver`` 为 'paddle' 时设定具体 driver 实例的一些参数:
|
||||
|
||||
* fleet_kwargs -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括:
|
||||
@ -494,9 +495,6 @@ class Trainer(TrainerEventTrigger):
|
||||
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
|
||||
reproducible=self.callback_manager._need_reproducible_sampler)
|
||||
|
||||
_torch_kwargs = kwargs.get("torch_kwargs", {})
|
||||
self.set_grad_to_none = _torch_kwargs.get("set_grad_to_none", True)
|
||||
|
||||
self.evaluate_batch_step_fn = evaluate_batch_step_fn
|
||||
self.kwargs = kwargs
|
||||
|
||||
@ -596,7 +594,7 @@ class Trainer(TrainerEventTrigger):
|
||||
try:
|
||||
self.on_train_begin()
|
||||
self.driver.barrier()
|
||||
self.driver.zero_grad(self.set_grad_to_none)
|
||||
self.driver.zero_grad()
|
||||
while self.cur_epoch_idx < self.n_epochs:
|
||||
# 这个是防止在 Trainer.load_checkpoint 之后还没结束当前 epoch 又继续 save
|
||||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
|
||||
@ -1236,7 +1234,7 @@ class Trainer(TrainerEventTrigger):
|
||||
"""
|
||||
if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
|
||||
self.on_before_zero_grad(self.optimizers)
|
||||
self.driver.zero_grad(self.set_grad_to_none)
|
||||
self.driver.zero_grad()
|
||||
self.on_after_zero_grad(self.optimizers)
|
||||
|
||||
def step(self):
|
||||
|
@ -198,12 +198,11 @@ class Driver(ABC):
|
||||
raise NotImplementedError("Each specific driver should implemented its own `step` function.")
|
||||
|
||||
@abstractmethod
|
||||
def zero_grad(self, set_to_none: bool = False):
|
||||
def zero_grad(self):
|
||||
r"""
|
||||
实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零;
|
||||
注意梯度累积不需要在这里实现,trainer 已经在内部实现了梯度累积;
|
||||
|
||||
:param set_to_none: 用来判断是否需要将梯度直接置为 None;
|
||||
"""
|
||||
raise NotImplementedError("Each specific driver should implemented its own `zero_grad` function.")
|
||||
|
||||
|
@ -46,7 +46,7 @@ class JittorSingleDriver(JittorDriver):
|
||||
for optimizer in self.optimizers:
|
||||
optimizer.backward(loss)
|
||||
|
||||
def zero_grad(self, set_to_none=False):
|
||||
def zero_grad(self):
|
||||
for optimizer in self.optimizers:
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
@ -199,7 +199,7 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
paddle_kwargs = kwargs.get("paddle_kwargs", {})
|
||||
|
||||
self._fleet_kwargs = paddle_kwargs.get("fleet_kwargs", {})
|
||||
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)
|
||||
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__, DataParallel.__name__)
|
||||
# fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档
|
||||
self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy())
|
||||
self.is_collective = self._fleet_kwargs.pop("is_collective", True)
|
||||
|
@ -82,13 +82,7 @@ class PaddleDriver(Driver):
|
||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
|
||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False):
|
||||
r"""
|
||||
实现深度学习中的梯度的置零操作,应当直接通过优化器 ``optimizers`` 来将梯度置零;
|
||||
注意梯度累积不需要在这里实现,:class:`~fastNLP.core.Trainer` 已经在内部实现了梯度累积;
|
||||
|
||||
:param set_to_none: 用来判断是否需要将梯度直接置为 ``None``;在 **PaddlePaddle** 中这个参数无效。
|
||||
"""
|
||||
def zero_grad(self):
|
||||
for optimizer in self.optimizers:
|
||||
optimizer.clear_grad()
|
||||
|
||||
@ -194,7 +188,7 @@ class PaddleDriver(Driver):
|
||||
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
|
||||
paddle.jit.save(model, filepath, input_spec)
|
||||
|
||||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
|
||||
def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs):
|
||||
model = self.unwrap_model()
|
||||
if isinstance(filepath, Path):
|
||||
filepath = str(filepath)
|
||||
@ -274,21 +268,10 @@ class PaddleDriver(Driver):
|
||||
# 2. 保存模型的状态;
|
||||
if should_save_model:
|
||||
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
|
||||
if only_state_dict:
|
||||
logger.debug("Save model state dict.")
|
||||
else:
|
||||
logger.debug("Save model.")
|
||||
|
||||
# 3. 保存 optimizers 的状态;
|
||||
optimizers_state_dict = {}
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: Optimizer = self.optimizers[i]
|
||||
optimizer_state = optimizer.state_dict()
|
||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu")
|
||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;
|
||||
|
||||
states["optimizers_state_dict"] = self.get_optimizer_state()
|
||||
logger.debug("Save optimizer state dict.")
|
||||
states["optimizers_state_dict"] = optimizers_state_dict
|
||||
|
||||
# 4.保存fp16的状态
|
||||
if not isinstance(self.grad_scaler, DummyGradScaler):
|
||||
@ -297,34 +280,42 @@ class PaddleDriver(Driver):
|
||||
|
||||
paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))
|
||||
|
||||
def get_optimizer_state(self):
|
||||
optimizers_state_dict = {}
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: Optimizer = self.optimizers[i]
|
||||
optimizer_state = optimizer.state_dict()
|
||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu")
|
||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;
|
||||
|
||||
return optimizers_state_dict
|
||||
|
||||
def load_optimizer_state(self, states):
|
||||
assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
|
||||
f"checkpoint it is:{len(states)}"
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: Optimizer = self.optimizers[i]
|
||||
optimizer.set_state_dict(states[f"optimizer{i}"])
|
||||
logger.debug("Load optimizer state dict.")
|
||||
|
||||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
|
||||
|
||||
states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))
|
||||
|
||||
# 1. 加载 optimizers 的状态;
|
||||
optimizers_state_dict = states.pop("optimizers_state_dict")
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: Optimizer = self.optimizers[i]
|
||||
optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"])
|
||||
logger.debug("Load optimizer state dict.")
|
||||
self.load_optimizer_state(optimizers_state_dict)
|
||||
|
||||
# 2. 加载模型状态;
|
||||
if should_load_model:
|
||||
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict)
|
||||
if only_state_dict:
|
||||
logger.debug("Load model state dict...")
|
||||
else:
|
||||
logger.debug("Load model...")
|
||||
|
||||
# 3. 加载fp16的状态;
|
||||
if "grad_scaler_state_dict" in states:
|
||||
grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
|
||||
if isinstance(self.grad_scaler, DummyGradScaler):
|
||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
|
||||
self.grad_scaler = _grad_scaler()
|
||||
self.fp16 = True
|
||||
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
|
||||
logger.debug("Load grad_scaler state dict...")
|
||||
if not isinstance(self.grad_scaler, DummyGradScaler):
|
||||
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
|
||||
logger.debug("Load grad_scaler state dict...")
|
||||
elif not isinstance(self.grad_scaler, DummyGradScaler):
|
||||
logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
|
||||
f"the training process may be unstable.")
|
||||
@ -347,7 +338,7 @@ class PaddleDriver(Driver):
|
||||
batch_size=dataloader_args.batch_size,
|
||||
drop_last=dataloader_args.drop_last
|
||||
)
|
||||
sampler.load_state_dict(states["sampler_states"])
|
||||
sampler.load_state_dict(states.pop("sampler_states"))
|
||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
|
||||
|
||||
# 5. 修改 trainer_state.batch_idx_in_epoch
|
||||
|
@ -304,11 +304,11 @@ class TorchDDPDriver(TorchDriver):
|
||||
self.global_rank = 0
|
||||
|
||||
self._ddp_kwargs = self._torch_kwargs.get("ddp_kwargs", {})
|
||||
check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__)
|
||||
check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__)
|
||||
if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None:
|
||||
logger.info("Notice your model has buffers and you are using `TorchDDPDriver`, but you do not set "
|
||||
"'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set"
|
||||
" to 'False' to avoid redundant data translation between different processes.")
|
||||
" to 'False' to avoid redundant data communication between different processes.")
|
||||
|
||||
self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
|
||||
assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type."
|
||||
@ -471,7 +471,7 @@ class TorchDDPDriver(TorchDriver):
|
||||
self._global_rank = rank
|
||||
|
||||
@property
|
||||
def local_rank(self) -> int:
|
||||
def local_rank(self) -> int: # 这个不会受到 all_rank_call_context 的影响
|
||||
return int(os.environ.get("LOCAL_RANK", 0))
|
||||
|
||||
@property
|
||||
|
307
fastNLP/core/drivers/torch_driver/fairscale.py
Normal file
307
fastNLP/core/drivers/torch_driver/fairscale.py
Normal file
@ -0,0 +1,307 @@
|
||||
__all__ = [
|
||||
'FairScaleDriver'
|
||||
]
|
||||
from typing import List, Sequence, Union, Dict, Mapping
|
||||
from pathlib import Path
|
||||
import os
|
||||
import functools
|
||||
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_FAIRSCALE
|
||||
if _NEED_IMPORT_FAIRSCALE:
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from fairscale.optim import OSS
|
||||
from fairscale.nn import ShardedDataParallel
|
||||
from fairscale.nn import FullyShardedDataParallel
|
||||
from fairscale.optim.grad_scaler import ShardedGradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from fairscale.nn.wrap import auto_wrap, enable_wrap, default_auto_wrap_policy
|
||||
|
||||
from ...log import logger
|
||||
from .utils import reset_seed, _DDPWrappingModel
|
||||
|
||||
from .ddp import TorchDDPDriver
|
||||
from .torch_driver import TorchDriver
|
||||
from .utils import _build_fp16_env
|
||||
from ....envs.distributed import all_rank_call_context
|
||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
|
||||
from .utils import optimizer_state_to_device
|
||||
|
||||
|
||||
class FairScaleDriver(TorchDDPDriver):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
parallel_device: Union[List["torch.device"], "torch.device"],
|
||||
is_pull_by_torch_run = False,
|
||||
fp16: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
assert _NEED_IMPORT_FAIRSCALE, "fairscale is not imported."
|
||||
assert not dist.is_initialized(), "FairScaleDriver does not support initialize distributed by user."
|
||||
self._fairscale_kwargs = kwargs.get('fairscale_kwargs', {})
|
||||
self.fs_type = self._fairscale_kwargs.get('fs_type', 'sdp') # ddp, sdp, fsdp
|
||||
if self.fs_type == 'fsdp':
|
||||
self._fairscale_kwargs['set_grad_to_none'] = self._fairscale_kwargs.get('set_grad_to_none', True)
|
||||
# 将最顶上的进行初始化
|
||||
kwargs.pop('torch_kwargs', None)
|
||||
TorchDriver.__init__(self, model=model, fp16=False, torch_kwargs=self._fairscale_kwargs, **kwargs)
|
||||
self.is_pull_by_torch_run = is_pull_by_torch_run
|
||||
assert self.fs_type in ['ddp', 'sdp', 'fsdp']
|
||||
self._oss_kwargs = self._fairscale_kwargs.get('oss_kwargs', {}) # 仅在 ddp 和 sdp 下有使用到
|
||||
self._sdp_kwargs = self._fairscale_kwargs.get('sdp_kwargs', {})
|
||||
self._fdsp_kwargs = self._fairscale_kwargs.get('fsdp_kwargs', {})
|
||||
self._ddp_kwargs = self._fairscale_kwargs.get('ddp_kwargs', {})
|
||||
|
||||
if self.fs_type == 'ddp' or fp16 is False:
|
||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
|
||||
self.grad_scaler = _grad_scaler(**self._fairscale_kwargs.get('gradscaler_kwargs', {}))
|
||||
else:
|
||||
self.auto_cast, self.grad_scaler = torch.cuda.amp.autocast, \
|
||||
ShardedGradScaler(**self._fairscale_kwargs.get('gradscaler_kwargs', {}))
|
||||
|
||||
self.parallel_device = parallel_device
|
||||
if is_pull_by_torch_run:
|
||||
self.model_device = parallel_device
|
||||
else:
|
||||
self.model_device = parallel_device[self.local_rank]
|
||||
|
||||
self.outside_ddp = False # 不允许在外部初始化
|
||||
self._data_device = kwargs.get("data_device", None)
|
||||
if isinstance(self._data_device, int):
|
||||
if self._data_device < 0:
|
||||
raise ValueError("Parameter `data_device` can not be smaller than 0.")
|
||||
_could_use_device_num = torch.cuda.device_count()
|
||||
if self._data_device >= _could_use_device_num:
|
||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.")
|
||||
self._data_device = torch.device(f"cuda:{self._data_device}")
|
||||
elif isinstance(self._data_device, str):
|
||||
self._data_device = torch.device(self._data_device)
|
||||
elif self._data_device is not None and not isinstance(self._data_device, torch.device):
|
||||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
|
||||
|
||||
self._master_port = None
|
||||
# world_size 表示的就是全局的显卡的数量;
|
||||
self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device)
|
||||
self.global_rank = 0
|
||||
|
||||
if self.fs_type == 'ddp':
|
||||
if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None:
|
||||
logger.info("Notice your model has buffers and you are using `FairScaleDriver`, but you do not set "
|
||||
"'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set"
|
||||
" to 'False' to avoid redundant data communication between different processes.")
|
||||
|
||||
self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
|
||||
assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type."
|
||||
if self.output_from_new_proc not in {"all", "ignore", "only_error"}:
|
||||
os.makedirs(self.output_from_new_proc, exist_ok=True)
|
||||
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)
|
||||
|
||||
self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的;
|
||||
self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹;
|
||||
|
||||
def setup(self):
|
||||
r"""
|
||||
准备分布式环境,该函数主要做以下两件事情:
|
||||
|
||||
1. 开启多进程,每个 gpu 设备对应单独的一个进程;
|
||||
2. 每个进程将模型迁移到自己对应的 ``gpu`` 设备上;然后使用 ``DistributedDataParallel`` 包裹模型;
|
||||
"""
|
||||
if self._has_setup:
|
||||
return
|
||||
self._has_setup = True
|
||||
if self.is_pull_by_torch_run:
|
||||
# dist.get_world_size() 只能在 dist.init_process_group 初始化之后进行调用;
|
||||
self.world_size = int(os.environ.get("WORLD_SIZE"))
|
||||
self.global_rank = int(os.environ.get("RANK"))
|
||||
reset_seed()
|
||||
logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
|
||||
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group(
|
||||
backend="nccl", rank=self.global_rank, world_size=self.world_size
|
||||
)
|
||||
|
||||
os.environ["fastnlp_torch_launch_not_ddp"] = "yes"
|
||||
else:
|
||||
if not dist.is_initialized():
|
||||
# 这里主要的问题在于要区分 rank0 和其它 rank 的情况;
|
||||
self.world_size = len(self.parallel_device)
|
||||
self.open_subprocess()
|
||||
self.global_rank = self.local_rank # rank 一定是通过环境变量去获取的;
|
||||
reset_seed()
|
||||
dist.init_process_group(
|
||||
backend="nccl", rank=self.global_rank, world_size=self.world_size
|
||||
)
|
||||
# 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 TorchDDPDriver;
|
||||
else:
|
||||
# 如果 `dist.is_initialized() == True`,那么说明 TorchDDPDriver 在之前已经初始化并且已经 setup 过一次,那么我们需要保证现在
|
||||
# 使用的(即之后的)TorchDDPDriver 的设置和第一个 TorchDDPDriver 是完全一样的;
|
||||
pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK])
|
||||
if pre_num_processes != len(self.parallel_device):
|
||||
raise RuntimeError(
|
||||
"Notice you are using `TorchDDPDriver` after one instantiated `TorchDDPDriver`, it is not"
|
||||
"allowed that your second `TorchDDPDriver` has a new setting of parameters "
|
||||
"`num_nodes` and `num_processes`.")
|
||||
self.world_size = dist.get_world_size()
|
||||
self.global_rank = dist.get_rank()
|
||||
|
||||
torch.cuda.set_device(self.model_device)
|
||||
if self.fs_type != 'fsdp':
|
||||
self.model.to(self.model_device)
|
||||
self.configure_ddp()
|
||||
|
||||
self.barrier()
|
||||
# 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作;
|
||||
self._pids = [torch.tensor(0, dtype=torch.int).to(self.data_device) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(self._pids, torch.tensor(os.getpid(), dtype=torch.int).to(self.data_device))
|
||||
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None
|
||||
if local_world_size is None:
|
||||
local_world_size = torch.tensor(int(os.environ.get("LOCAL_RANK")), dtype=torch.int).to(self.data_device)
|
||||
dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX)
|
||||
local_world_size = local_world_size.tolist() + 1
|
||||
|
||||
node_rank = self.global_rank // local_world_size
|
||||
self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size]
|
||||
self._pids = self.tensor_to_numeric(self._pids)
|
||||
|
||||
def configure_ddp(self):
|
||||
model = _DDPWrappingModel(self.model)
|
||||
if self.fs_type == 'ddp':
|
||||
self.model = DistributedDataParallel(
|
||||
# 注意这里的 self.model_device 是 `torch.device` type,因此 self.model_device.index;
|
||||
model, device_ids=[self.model_device.index],
|
||||
**self._ddp_kwargs
|
||||
)
|
||||
elif self.fs_type == 'sdp':
|
||||
sdp_kwargs = self._sdp_kwargs
|
||||
sdp_kwargs = {**sdp_kwargs, 'module': model}
|
||||
sdp_kwargs['reduce_fp16'] = sdp_kwargs.get('reduce_fp16', self.fp16)
|
||||
oss_lst = []
|
||||
for optimizer in self.optimizers:
|
||||
oss = OSS(optimizer.param_groups, optim=type(optimizer), **optimizer.defaults)
|
||||
oss_lst.append(oss)
|
||||
sdp_kwargs['sharded_optimizer'] = oss_lst
|
||||
sdp_kwargs['warn_on_trainable_params_changed'] = sdp_kwargs.get('warn_on_trainable_params_changed', False)
|
||||
self.model = ShardedDataParallel(**sdp_kwargs)
|
||||
self.optimizers = oss_lst
|
||||
else:
|
||||
assert len(self.optimizers) == 1, "When fs_type='fsdp', only one optimizer is allowed."
|
||||
optimizer = self.optimizers[0]
|
||||
assert len(optimizer.param_groups) == 1, "Cannot assign parameter specific optimizer parameter for 'fsdp'."
|
||||
fsdp_kwargs = self._fdsp_kwargs
|
||||
fsdp_kwargs['mixed_precision'] = self.fp16
|
||||
fsdp_kwargs['state_dict_on_rank_0_only'] = fsdp_kwargs.get('state_dict_on_rank_0_only', True)
|
||||
fsdp_kwargs['state_dict_device'] = fsdp_kwargs.get('state_dict_device', torch.device('cpu'))
|
||||
fsdp_kwargs['compute_device'] = fsdp_kwargs.get('compute_device', self.model_device)
|
||||
optimizer = self.optimizers[0]
|
||||
# wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=1e6)
|
||||
# with enable_wrap(wrapper_cls=FullyShardedDataParallel, auto_wrap_policy=wrap_policy,
|
||||
# **fsdp_kwargs):
|
||||
# model = auto_wrap(model)
|
||||
fsdp_kwargs = {**fsdp_kwargs, 'module': model}
|
||||
self.model = None # 释放掉
|
||||
self.model = FullyShardedDataParallel(**fsdp_kwargs).to(self.model_device)
|
||||
self.optimizers = type(optimizer)(self.model.parameters(), **optimizer.defaults)
|
||||
|
||||
self._has_ddpwrapped = True
|
||||
|
||||
def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs):
|
||||
"""
|
||||
保存当前 driver 的模型到 folder 下。
|
||||
|
||||
:param filepath: 保存到哪个文件夹;
|
||||
:param only_state_dict: 是否只保存权重;
|
||||
:return:
|
||||
"""
|
||||
if self.fs_type in ('ddp', 'sdp'):
|
||||
model = self.model.module.model
|
||||
|
||||
if only_state_dict:
|
||||
if self.fs_type != 'fsdp':
|
||||
if self.local_rank == 0:
|
||||
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
|
||||
else:
|
||||
# 所有 rank 都需要调用
|
||||
states = self.model.state_dict()
|
||||
if self.local_rank == 0:
|
||||
states = {key[len('model.'):]:value for key, value in states.items()} # 这里需要去掉那个 _wrap 的 key
|
||||
if self.local_rank == 0: #
|
||||
torch.save(states, filepath)
|
||||
elif self.fs_type == 'fsdp':
|
||||
raise RuntimeError("When fs_type='fsdp', only `only_state_dict=True` is allowed.")
|
||||
else:
|
||||
if self.local_rank == 0:
|
||||
torch.save(model, filepath)
|
||||
|
||||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
|
||||
"""
|
||||
从 folder 中加载权重并赋值到当前 driver 的模型上。
|
||||
|
||||
:param filepath: 加载权重或模型的路径
|
||||
:param load_state_dict: 保存的内容是否只是权重。
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
states = torch.load(filepath, map_location='cpu')
|
||||
if isinstance(states, dict) and only_state_dict is False:
|
||||
logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use "
|
||||
f"`only_state_dict=True`")
|
||||
elif not isinstance(states, dict) and only_state_dict is True:
|
||||
logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use "
|
||||
f"`only_state_dict=False`")
|
||||
if not isinstance(states, Mapping):
|
||||
states = states.state_dict()
|
||||
|
||||
if self.fs_type in ('ddp', 'sdp'):
|
||||
model = self.model.module.model
|
||||
else:
|
||||
model = self.model
|
||||
states = {f'model.{k}':v for k, v in states.items()}
|
||||
|
||||
model.load_state_dict(states)
|
||||
|
||||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
|
||||
if self.fs_type == 'fsdp':
|
||||
if should_save_model is False:
|
||||
logger.warning("When save model using fs_type='fsdp', please make sure use "
|
||||
"`with trainer.driver.model.summon_full_params():` context to gather all parameters.")
|
||||
with all_rank_call_context():
|
||||
super().save_checkpoint(folder=folder, states=states, dataloader=dataloader, only_state_dict=only_state_dict,
|
||||
should_save_model=should_save_model, **kwargs)
|
||||
else:
|
||||
super().save_checkpoint(folder=folder, states=states, dataloader=dataloader,
|
||||
only_state_dict=only_state_dict, should_save_model=should_save_model, **kwargs)
|
||||
|
||||
def get_optimizer_state(self):
|
||||
optimizers_state_dict = {}
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: torch.optim.Optimizer = self.optimizers[i]
|
||||
if self.fs_type == 'fsdp':
|
||||
optimizer_state = self.model.gather_full_optim_state_dict(optimizer)
|
||||
elif self.fs_type == 'sdp':
|
||||
optimizer.consolidate_state_dict(recipient_rank=0)
|
||||
else:
|
||||
optimizer_state = optimizer.state_dict()
|
||||
if self.local_rank == 0:
|
||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
|
||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;
|
||||
return optimizers_state_dict
|
||||
|
||||
def load_optimizer_state(self, states):
|
||||
assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
|
||||
f"checkpoint it is:{len(states)}"
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: torch.optim.Optimizer = self.optimizers[i]
|
||||
state = states[f'optimizer{i}']
|
||||
if self.fs_type == 'fsdp':
|
||||
state = self.model.get_shard_from_optim_state_dict(state)
|
||||
optimizer.load_state_dict(state)
|
||||
|
||||
logger.debug("Load optimizer state dict.")
|
||||
|
||||
def unwrap_model(self):
|
||||
r"""
|
||||
:return: 返回原本的模型,例如没有被 ``DataParallel`` 包裹;
|
||||
"""
|
||||
return self.model.module.model
|
@ -1,63 +0,0 @@
|
||||
from typing import List
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_FAIRSCALE
|
||||
if _NEED_IMPORT_FAIRSCALE:
|
||||
import torch
|
||||
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
|
||||
from fairscale.optim import OSS
|
||||
|
||||
__all__ = [
|
||||
'ShardedDriver'
|
||||
]
|
||||
|
||||
from .ddp import TorchDDPDriver
|
||||
|
||||
|
||||
# todo 注意 fairscale 现在几乎所有的功能都没有实现;
|
||||
# TODO:预跑前后对模型和 optimizers 的支持;
|
||||
# TODO:fairscale 的 fp16 额外的处理;
|
||||
class ShardedDriver(TorchDDPDriver):
|
||||
_REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23 # 8M
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
parallel_device: List["torch.device"],
|
||||
num_nodes: int = 1,
|
||||
fp16: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
super(ShardedDriver, self).__init__(
|
||||
model=model,
|
||||
parallel_device=parallel_device,
|
||||
num_nodes=num_nodes,
|
||||
fp16=fp16,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def configure_ddp(self):
|
||||
if "reduce_buffer_size" not in self._ddp_kwargs:
|
||||
# For multi-node training, enabling bucketing will improve performance.
|
||||
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
|
||||
|
||||
self.optimizers = self._wrap_optimizers(self.optimizers)
|
||||
self.model = ShardedDataParallel(self.model, sharded_optimizer=self.optimizers, **self._ddp_kwargs)
|
||||
|
||||
|
||||
def _wrap_optimizers(self, optimizers) -> List["OSS"]:
|
||||
# TODO:之后得去研究一下 pytorch lightning 为什么这样写,我们是不是也需要这样写;
|
||||
# if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
|
||||
# return optimizers
|
||||
|
||||
return self._reinit_optimizers_with_oss(optimizers)
|
||||
|
||||
def _reinit_optimizers_with_oss(self, optimizers) -> List["OSS"]:
|
||||
for x, optimizer in enumerate(optimizers):
|
||||
if not isinstance(optimizer, OSS):
|
||||
optim_class = type(optimizer)
|
||||
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
|
||||
|
||||
# TODO:具体细节见 pytorch lightning 的这一函数,主要的点在于加入 fp16 相关的一些东西;
|
||||
optimizers[x] = zero_optimizer
|
||||
del optimizer
|
||||
return optimizers
|
||||
|
@ -7,11 +7,14 @@ if _NEED_IMPORT_TORCH:
|
||||
from .torch_driver import TorchDriver
|
||||
from .single_device import TorchSingleDriver
|
||||
from .ddp import TorchDDPDriver
|
||||
from .fairscale import FairScaleDriver
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.envs import FASTNLP_BACKEND_LAUNCH
|
||||
from pkg_resources import parse_version
|
||||
|
||||
__all__ = []
|
||||
|
||||
|
||||
def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.device", int, List[int]]],
|
||||
model: "torch.nn.Module", **kwargs) -> TorchDriver:
|
||||
r"""
|
||||
@ -23,13 +26,20 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
|
||||
|
||||
:return: 返回一个 :class:`~fastNLP.core.TorchSingleDriver` 或 :class:`~fastNLP.core.TorchDDPDriver` 实例;
|
||||
"""
|
||||
if parse_version(torch.__version__) < parse_version('1.6'):
|
||||
raise RuntimeError(f"Pytorch(current version:{torch.__version__}) need to be older than 1.6.")
|
||||
# world_size 和 rank
|
||||
if FASTNLP_BACKEND_LAUNCH in os.environ:
|
||||
if device is not None:
|
||||
logger.rank_zero_warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
|
||||
"up your script. And we will directly get the local device via "
|
||||
"`os.environ['LOCAL_RANK']`.", once=True)
|
||||
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)
|
||||
if driver == 'fairscale':
|
||||
return FairScaleDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
|
||||
is_pull_by_torch_run=True, **kwargs)
|
||||
else:
|
||||
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
|
||||
is_pull_by_torch_run=True, **kwargs)
|
||||
|
||||
if driver not in {"torch", "fairscale"}:
|
||||
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].")
|
||||
@ -67,13 +77,10 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
|
||||
else:
|
||||
return TorchDDPDriver(model, device, **kwargs)
|
||||
elif driver == "fairscale":
|
||||
raise NotImplementedError("`fairscale` is not support right now.")
|
||||
# if not isinstance(device, List):
|
||||
# if device.type == 'cpu':
|
||||
# raise ValueError("You are using `fairscale` driver, but your chosen `device` is 'cpu'.")
|
||||
# log.info("Notice you are using `fairscale` driver, but your chosen `device` is only one gpu, we will"
|
||||
# "still use `fairscale` for you, but if you mean using `TorchSingleDriver`, you should "
|
||||
# "choose `torch` driver.")
|
||||
# return ShardedDriver(model, [device], **kwargs)
|
||||
# else:
|
||||
# return ShardedDriver(model, device, **kwargs)
|
||||
if not isinstance(device, List):
|
||||
if device.type == 'cpu':
|
||||
raise ValueError("You are using `fairscale` driver, but your chosen `device` is 'cpu'.")
|
||||
logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.")
|
||||
return FairScaleDriver(model, [device], **kwargs)
|
||||
else:
|
||||
return FairScaleDriver(model, device, **kwargs)
|
@ -1,7 +1,6 @@
|
||||
import os
|
||||
from typing import Union, Dict, Optional, Callable
|
||||
from functools import partial
|
||||
from pkg_resources import parse_version
|
||||
import numpy as np
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
@ -52,23 +51,23 @@ class TorchDriver(Driver):
|
||||
super(TorchDriver, self).__init__(model)
|
||||
|
||||
""" 进行 fp16 的设置 """
|
||||
self._torch_kwargs = kwargs.get("torch_kwargs", {})
|
||||
|
||||
# 因为 ddp 和 single_device 的混合精度训练的设置是一样的,因此可以统一抽象到这里;
|
||||
self.fp16 = fp16
|
||||
if parse_version(torch.__version__) < parse_version('1.6'):
|
||||
raise RuntimeError(f"Pytorch({torch.__version__}) need to be older than 1.6.")
|
||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
|
||||
self.grad_scaler = _grad_scaler()
|
||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not self.fp16)
|
||||
self.grad_scaler = _grad_scaler(**self._torch_kwargs.get('gradscaler_kwargs', {}))
|
||||
self.set_grad_to_none = self._torch_kwargs.get('set_grad_to_none')
|
||||
|
||||
self._torch_kwargs = kwargs.get("torch_kwargs", {})
|
||||
# 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数;
|
||||
self.non_blocking = self._torch_kwargs.get("torch_non_blocking", True)
|
||||
self.non_blocking = self._torch_kwargs.get("non_blocking", True)
|
||||
|
||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
|
||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False):
|
||||
def zero_grad(self):
|
||||
for optimizer in self.optimizers:
|
||||
self._clear_grad(optimizer, set_to_none)
|
||||
self._clear_grad(optimizer, self.set_grad_to_none)
|
||||
|
||||
def _clear_grad(self, optimizer, set_to_none):
|
||||
param_groups = optimizer.param_groups
|
||||
@ -178,7 +177,7 @@ class TorchDriver(Driver):
|
||||
else:
|
||||
torch.save(model, filepath)
|
||||
|
||||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
|
||||
def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs):
|
||||
"""
|
||||
从 folder 中加载权重并赋值到当前 driver 的模型上。
|
||||
|
||||
@ -195,10 +194,9 @@ class TorchDriver(Driver):
|
||||
elif not isinstance(res, dict) and only_state_dict is True:
|
||||
logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use "
|
||||
f"`only_state_dict=False`")
|
||||
if only_state_dict:
|
||||
model.load_state_dict(res)
|
||||
else:
|
||||
model.load_state_dict(res.state_dict())
|
||||
if not isinstance(res, dict):
|
||||
res = res.state_dict()
|
||||
model.load_state_dict(res)
|
||||
|
||||
@rank_zero_call
|
||||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
|
||||
@ -246,25 +244,13 @@ class TorchDriver(Driver):
|
||||
|
||||
# 2. 保存模型的状态;
|
||||
if should_save_model:
|
||||
model = self.unwrap_model()
|
||||
if not os.path.exists(folder):
|
||||
os.mkdir(folder)
|
||||
if only_state_dict:
|
||||
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
|
||||
# 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失;
|
||||
torch.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME))
|
||||
logger.debug("Save model state dict")
|
||||
else:
|
||||
torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME))
|
||||
logger.debug("Save model")
|
||||
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
|
||||
self.save_model(model_path, only_state_dict=only_state_dict)
|
||||
|
||||
# 3. 保存 optimizers 的状态;
|
||||
optimizers_state_dict = {}
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: torch.optim.Optimizer = self.optimizers[i]
|
||||
optimizer_state = optimizer.state_dict()
|
||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
|
||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;
|
||||
optimizers_state_dict = self.get_optimizer_state()
|
||||
|
||||
# 4. 保存fp16的状态
|
||||
if not isinstance(self.grad_scaler, DummyGradScaler):
|
||||
@ -275,38 +261,42 @@ class TorchDriver(Driver):
|
||||
states["optimizers_state_dict"] = optimizers_state_dict
|
||||
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
|
||||
|
||||
def get_optimizer_state(self):
|
||||
optimizers_state_dict = {}
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: torch.optim.Optimizer = self.optimizers[i]
|
||||
optimizer_state = optimizer.state_dict()
|
||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
|
||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;
|
||||
return optimizers_state_dict
|
||||
|
||||
def load_optimizer_state(self, states):
|
||||
assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
|
||||
f"checkpoint it is:{len(states)}"
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: torch.optim.Optimizer = self.optimizers[i]
|
||||
optimizer.load_state_dict(states[f"optimizer{i}"])
|
||||
logger.debug("Load optimizer state dict.")
|
||||
|
||||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
|
||||
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
|
||||
|
||||
# 1. 加载 optimizers 的状态;
|
||||
optimizers_state_dict = states.pop("optimizers_state_dict")
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: torch.optim.Optimizer = self.optimizers[i]
|
||||
optimizer.load_state_dict(optimizers_state_dict[f"optimizer{i}"])
|
||||
logger.debug("Load optimizer state dict.")
|
||||
self.load_optimizer_state(optimizers_state_dict)
|
||||
|
||||
# 2. 加载模型状态;
|
||||
if should_load_model:
|
||||
model = self.unwrap_model()
|
||||
res = torch.load(folder.joinpath(FASTNLP_MODEL_FILENAME), map_location='cpu')
|
||||
if only_state_dict:
|
||||
model.load_state_dict(res)
|
||||
logger.debug("Load model state dict...")
|
||||
else:
|
||||
model.load_state_dict(res.state_dict())
|
||||
logger.debug("Load model...")
|
||||
self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict)
|
||||
|
||||
# 3. 加载fp16的状态
|
||||
if "grad_scaler_state_dict" in states:
|
||||
grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
|
||||
if isinstance(self.grad_scaler, DummyGradScaler):
|
||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
|
||||
self.grad_scaler = _grad_scaler()
|
||||
self.fp16 = True
|
||||
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
|
||||
logger.debug("Load grad_scaler state dict...")
|
||||
if not isinstance(self.grad_scaler, DummyGradScaler):
|
||||
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
|
||||
logger.debug("Load grad_scaler state dict...")
|
||||
elif not isinstance(self.grad_scaler, DummyGradScaler):
|
||||
logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
|
||||
logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
|
||||
f"the training process may be unstable.")
|
||||
|
||||
# 4. 恢复 sampler 的状态;
|
||||
|
@ -5,6 +5,7 @@ __all__ = [
|
||||
from typing import Union, List
|
||||
from collections import Counter
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
from .metric import Metric
|
||||
from .backend import Backend
|
||||
@ -132,10 +133,10 @@ class ClassifyFPreRecMetric(Metric):
|
||||
seq_len = self.tensor2numpy(seq_len)
|
||||
|
||||
if seq_len is not None and target.ndim > 1:
|
||||
max_len = target.ndim[-1]
|
||||
max_len = target.shape[-1]
|
||||
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
|
||||
else:
|
||||
masks = None
|
||||
masks = np.ones_like(target)
|
||||
|
||||
if pred.ndim == target.ndim:
|
||||
if len(pred.flatten()) != len(target.flatten()):
|
||||
@ -143,7 +144,6 @@ class ClassifyFPreRecMetric(Metric):
|
||||
f" while target have element numbers:{len(pred.flatten())}, "
|
||||
f"pred have element numbers: {len(target.flatten())}")
|
||||
|
||||
pass
|
||||
elif pred.ndim == target.ndim + 1:
|
||||
pred = pred.argmax(axis=-1)
|
||||
if seq_len is None and target.ndim > 1:
|
||||
@ -152,11 +152,9 @@ class ClassifyFPreRecMetric(Metric):
|
||||
raise RuntimeError(f"when pred have "
|
||||
f"size:{pred.shape}, target should have size: {pred.shape} or "
|
||||
f"{pred.shape[:-1]}, got {target.shape}.")
|
||||
if masks is not None:
|
||||
target = target * masks
|
||||
pred = pred * masks
|
||||
target_idxes = set(target.reshape(-1).tolist())
|
||||
|
||||
target_idxes = set(target.reshape(-1).tolist()+pred.reshape(-1).tolist())
|
||||
for target_idx in target_idxes:
|
||||
self._tp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item()
|
||||
self._fp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item()
|
||||
self._fn[target_idx] += ((pred != target_idx) * (target != target_idx)).sum().item()
|
||||
self._tp[target_idx] += ((pred == target_idx) * (target == target_idx) * masks).sum().item()
|
||||
self._fp[target_idx] += ((pred == target_idx) * (target != target_idx) * masks).sum().item()
|
||||
self._fn[target_idx] += ((pred != target_idx) * (target == target_idx) * masks).sum().item()
|
||||
|
@ -227,7 +227,7 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
|
||||
raise e
|
||||
|
||||
|
||||
def check_user_specific_params(user_params: Dict, fn: Callable):
|
||||
def check_user_specific_params(user_params: Dict, fn: Callable, fn_name=None):
|
||||
"""
|
||||
该函数使用用户的输入来对指定函数的参数进行赋值,主要用于一些用户无法直接调用函数的情况;
|
||||
主要作用在于帮助检查用户对使用函数 ``fn`` 的参数输入是否有误;
|
||||
@ -235,13 +235,16 @@ def check_user_specific_params(user_params: Dict, fn: Callable):
|
||||
:param user_params: 用户指定的参数的值,应当是一个字典,其中 ``key`` 表示每一个参数的名字,
|
||||
``value`` 为每一个参数的值;
|
||||
:param fn: 将要被调用的函数;
|
||||
:param fn_name: 在打印提示信息是如何显示函数名
|
||||
:return: 返回一个字典,其中为在之后调用函数 ``fn`` 时真正会被传进去的参数的值;
|
||||
"""
|
||||
if fn_name is None:
|
||||
fn_name = fn.__name__
|
||||
|
||||
fn_arg_names = get_fn_arg_names(fn)
|
||||
for arg_name, arg_value in user_params.items():
|
||||
if arg_name not in fn_arg_names:
|
||||
logger.rank_zero_warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.")
|
||||
logger.rank_zero_warning(f"Notice parameter `{arg_name}` may not be used by `{fn_name}`.")
|
||||
return user_params
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@ else:
|
||||
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
_NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale.nn") and 'torch' in need_import
|
||||
_NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale") and 'torch' in need_import
|
||||
_NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import
|
||||
_NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import
|
||||
_NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import
|
||||
|
@ -277,13 +277,12 @@ def test_trainer_specific_params_1(
|
||||
|
||||
model_wo_auto_param_call=True,
|
||||
torch_kwargs={
|
||||
"torch_non_blocking": False,
|
||||
"non_blocking": False,
|
||||
"set_grad_to_none": True
|
||||
}
|
||||
|
||||
)
|
||||
|
||||
assert trainer.set_grad_to_none is True
|
||||
assert trainer.driver.non_blocking is False
|
||||
assert trainer.driver.wo_auto_param_call is True
|
||||
|
||||
@ -320,13 +319,11 @@ def test_trainer_specific_params_2(
|
||||
"broadcast_buffers": True,
|
||||
"find_unused_parameters": True
|
||||
},
|
||||
"torch_non_blocking": False,
|
||||
"set_grad_to_none": True
|
||||
"non_blocking": False,
|
||||
}
|
||||
|
||||
)
|
||||
|
||||
assert trainer.set_grad_to_none is True
|
||||
assert trainer.driver.non_blocking is False
|
||||
assert trainer.driver.wo_auto_param_call is True
|
||||
assert trainer.driver.output_from_new_proc == "all"
|
||||
|
@ -139,7 +139,7 @@ class TestFdl:
|
||||
logger.set_stdout()
|
||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
|
||||
with Capturing() as out:
|
||||
dl = TorchDataLoader(ds, prefetch_factor=3, shuffle=False)
|
||||
dl = TorchDataLoader(ds, batch_size=1, prefetch_factor=3, shuffle=False)
|
||||
for idx, batch in enumerate(dl):
|
||||
assert len(batch['x'])==1
|
||||
assert batch['x'][0].tolist() == ds[idx]['x']
|
||||
@ -154,7 +154,7 @@ class TestFdl:
|
||||
logger.set_stdout()
|
||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
|
||||
with Capturing() as out:
|
||||
dl = TorchDataLoader(ds, num_workers=0, prefetch_factor=2, generator=torch.Generator(), shuffle=False)
|
||||
dl = TorchDataLoader(ds, batch_size=1, num_workers=0, prefetch_factor=2, generator=torch.Generator(), shuffle=False)
|
||||
for idx, batch in enumerate(dl):
|
||||
assert len(batch['x'])==1
|
||||
assert batch['x'][0].tolist() == ds[idx]['x']
|
||||
|
@ -661,7 +661,7 @@ class TestSaveLoad:
|
||||
|
||||
# 3. 检查 fp16 是否被加载
|
||||
if fp16:
|
||||
assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
|
||||
assert not isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
|
||||
|
||||
# 4. 检查 model 的参数是否正确
|
||||
# 5. 检查 batch_idx
|
||||
@ -771,7 +771,7 @@ class TestSaveLoad:
|
||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
|
||||
# 3. 检查 fp16 是否被加载
|
||||
if fp16:
|
||||
assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
|
||||
assert not isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
|
||||
|
||||
# 4. 检查 model 的参数是否正确
|
||||
# 5. 检查 batch_idx
|
||||
|
@ -632,7 +632,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
|
||||
# 3. 检查 fp16 是否被加载
|
||||
if fp16:
|
||||
assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
|
||||
assert not isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
|
||||
|
||||
|
||||
# 4. 检查 model 的参数是否正确
|
||||
@ -720,7 +720,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
|
||||
|
||||
# 3. 检查 fp16 是否被加载
|
||||
if fp16:
|
||||
assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
|
||||
assert not isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
|
||||
|
||||
# 4. 检查 model 的参数是否正确
|
||||
# 5. 检查 batch_idx
|
||||
|
@ -682,7 +682,7 @@ class TestSaveLoad:
|
||||
|
||||
# 3. 检查 fp16 是否被加载
|
||||
if fp16:
|
||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
|
||||
assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
|
||||
|
||||
# 4. 检查 model 的参数是否正确
|
||||
# 5. 检查 batch_idx
|
||||
@ -731,7 +731,7 @@ class TestSaveLoad:
|
||||
"""
|
||||
|
||||
try:
|
||||
path = "model.ckp"
|
||||
path = "checkpoints/"
|
||||
|
||||
num_replicas = len(device)
|
||||
|
||||
@ -764,6 +764,7 @@ class TestSaveLoad:
|
||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
|
||||
else:
|
||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
|
||||
dist.barrier() # 等待save成功
|
||||
# 加载
|
||||
# 更改 batch_size
|
||||
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False)
|
||||
@ -788,7 +789,7 @@ class TestSaveLoad:
|
||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
|
||||
# 3. 检查 fp16 是否被加载
|
||||
if fp16:
|
||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
|
||||
assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
|
||||
|
||||
# 4. 检查 model 的参数是否正确
|
||||
# 5. 检查 batch_idx
|
||||
|
@ -1,6 +1,8 @@
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from pkg_resources import parse_version
|
||||
|
||||
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver
|
||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
@ -9,6 +11,7 @@ from tests.helpers.datasets.paddle_data import PaddleNormalDataset
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, BatchSampler
|
||||
@ -245,6 +248,9 @@ class TestTorchDriverFunctions:
|
||||
"""
|
||||
# 先确保不影响运行
|
||||
# TODO:正确性
|
||||
if parse_version(torch.__version__) < parse_version('1.7'):
|
||||
pytest.skip("Skip if torch version smaller than 1.6 since torch.manual_seed my cause bug:"
|
||||
"Overflow when unpacking long")
|
||||
TorchSingleDriver.worker_init_function(0)
|
||||
|
||||
@pytest.mark.torch
|
||||
@ -611,7 +617,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
|
||||
# 3. 检查 fp16 是否被加载
|
||||
if fp16:
|
||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
|
||||
assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
|
||||
|
||||
# 4. 检查 model 的参数是否正确
|
||||
# 5. 检查 batch_idx
|
||||
@ -683,7 +689,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
|
||||
|
||||
# 3. 检查 fp16 是否被加载
|
||||
if fp16:
|
||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
|
||||
assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
|
||||
|
||||
# 4. 检查 model 的参数是否正确
|
||||
# 5. 检查 batch_idx
|
||||
|
@ -31,7 +31,7 @@ def _test(local_rank: int, world_size: int, device: "torch.device",
|
||||
|
||||
my_result = metric.get_metric()
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
np.allclose(my_result[keys], metric_result[keys], atol=0.000001)
|
||||
assert np.allclose(my_result[keys], metric_result[keys], atol=0.000001)
|
||||
|
||||
|
||||
@pytest.mark.torch
|
||||
@ -69,7 +69,6 @@ class TestClassfiyFPreRecMetric:
|
||||
[-0.8088, -0.6648, -0.5018, -0.0230, -0.8207],
|
||||
[-0.7753, -0.3508, 1.6163, 0.7158, 1.5207],
|
||||
[0.8692, 0.7718, -0.6734, 0.6515, 0.0641]])
|
||||
arg_max_pred = torch.argmax(pred, dim=-1)
|
||||
target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3,
|
||||
0, 3, 0, 0, 0, 1, 3, 1])
|
||||
|
||||
@ -79,10 +78,9 @@ class TestClassfiyFPreRecMetric:
|
||||
f1_score = 0.1882051282051282
|
||||
recall = 0.1619047619047619
|
||||
pre = 0.23928571428571427
|
||||
|
||||
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
|
||||
assert np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
|
||||
|
||||
metric = ClassifyFPreRecMetric(f_type='micro')
|
||||
metric.update(pred, target)
|
||||
@ -93,7 +91,7 @@ class TestClassfiyFPreRecMetric:
|
||||
|
||||
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
|
||||
assert np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
|
||||
|
||||
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro')
|
||||
metric.update(pred, target)
|
||||
@ -103,19 +101,35 @@ class TestClassfiyFPreRecMetric:
|
||||
'1': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 5},
|
||||
'2': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 2},
|
||||
'3': {'f1-score': 0.30769230769230765, 'precision': 0.5, 'recall': 0.2222222222222222, 'support': 9},
|
||||
'4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9},
|
||||
'macro avg': {'f1-score': 0.1882051282051282, 'precision': 0.23928571428571427,
|
||||
'recall': 0.1619047619047619, 'support': 32},
|
||||
'micro avg': {'f1-score': 0.21875, 'precision': 0.21875, 'recall': 0.21875, 'support': 32},
|
||||
'weighted avg': {'f1-score': 0.2563301282051282, 'precision': 0.3286830357142857, 'recall': 0.21875,
|
||||
'support': 32}}
|
||||
'4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9}}
|
||||
for keys in result_dict.keys():
|
||||
if keys == "f" or "pre" or "rec":
|
||||
continue
|
||||
gl = str(keys[-1])
|
||||
tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"}
|
||||
gk = tmp_d[keys[0]]
|
||||
np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001)
|
||||
assert np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001)
|
||||
|
||||
def test_seq_len(self):
|
||||
pred = torch.tensor([[[0.3, 0.7, 0.1], [0.4, 0.1, 0.1], [0.3, 0.1, 0.7]],
|
||||
[[0.7, 0.1, 0.1], [0.5, 0.9, 0.1], [0.3, 0.1, 0.7]]])
|
||||
seq_len = torch.LongTensor([3, 2])
|
||||
target = torch.LongTensor([[1, 0, 2], [0, 1, 0]])
|
||||
|
||||
# 不考虑长度
|
||||
metric = ClassifyFPreRecMetric(only_gross=True, f_type='macro')
|
||||
metric.update(pred, target)
|
||||
result_dict = metric.get_metric()
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
assert result_dict[keys] != 1
|
||||
|
||||
# 考虑长度
|
||||
metric = ClassifyFPreRecMetric(only_gross=True, f_type='macro')
|
||||
metric.update(pred, target, seq_len=seq_len)
|
||||
result_dict = metric.get_metric()
|
||||
for keys in ['f', 'pre', 'rec']:
|
||||
assert result_dict[keys] == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("f_type, f1_score,recall,pre",
|
||||
[('macro', 0.1882051282051282, 0.1619047619047619, 0.23928571428571427),
|
||||
@ -180,3 +194,22 @@ class TestClassfiyFPreRecMetric:
|
||||
[(rank, NUM_PROCESSES, torch.device(f'cuda:{rank}')) for rank in range(NUM_PROCESSES)])
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
def test_binary(self):
|
||||
pred = torch.randn(10, 2)
|
||||
target = torch.randint(1, size=(10,))
|
||||
metric = ClassifyFPreRecMetric()
|
||||
metric.update(pred, target)
|
||||
results = metric.get_metric()
|
||||
print(target)
|
||||
print(metric._tp, metric._fp, metric._fn)
|
||||
assert results['f']==results['rec']==results['pre']
|
||||
|
||||
pred = torch.randn(10, 2)
|
||||
target = torch.randint(2, size=(10,))
|
||||
metric = ClassifyFPreRecMetric()
|
||||
metric.update(pred, target)
|
||||
results = metric.get_metric()
|
||||
print(target)
|
||||
print(metric._tp, metric._fp, metric._fn)
|
||||
assert results['f']==results['rec']==results['pre']
|
||||
|
@ -226,7 +226,7 @@ class TestSpanFPreRecMetric:
|
||||
# print(expected_metric)
|
||||
metric_value = metric.get_metric()
|
||||
for key, value in expected_metric.items():
|
||||
np.allclose(value, metric_value[key])
|
||||
assert np.allclose(value, metric_value[key])
|
||||
|
||||
def test_auto_encoding_type_infer(self):
|
||||
# 检查是否可以自动check encode的类型
|
||||
|
Loading…
Reference in New Issue
Block a user