mirror of https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 03:37:55 +08:00

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

commit a2956b697e
@@ -11,7 +11,10 @@ __all__ = [
     'RichCallback',
     "LRSchedCallback",
     'LoadBestModelCallback',
-    "EarlyStopCallback"
+    "EarlyStopCallback",
+
+    "TorchWarmupCallback",
+    "TorchGradClipCallback"
 ]

@@ -23,4 +26,5 @@ from .progress_callback import choose_progress_callback, ProgressCallback, RichC
 from .lr_scheduler_callback import LRSchedCallback
 from .load_best_model_callback import LoadBestModelCallback
 from .early_stop_callback import EarlyStopCallback
+from .torch_callbacks import *
@@ -1,16 +1,12 @@
-from typing import Union, Callable, Dict, Optional, Any
-from abc import ABC
-
 __all__ = [
     'Callback',
 ]
 
+from typing import Union, Callable, Dict, Optional, Any
+
 from .callback_events import Events, EventsList, Filter
-from .utils import _get_monitor_value
 from fastNLP.core.callbacks.callback_events import _SingleEventState
 from fastNLP.core.log import logger
-from fastNLP.core.utils import apply_to_collection
-from fastNLP.core.utils.utils import _check_valid_parameters_number
 
 
 class Callback:
@@ -278,135 +274,3 @@ class _CallbackWrapper(Callback):
     @property
     def callback_name(self):
         return self.fn.__name__
-
-
-class CanItemDataType(ABC):
-    """
-    Detect objects that can be transferred (i.e. that expose an `item()` method).
-
-    """
-
-    @classmethod
-    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
-        if cls is CanItemDataType:
-            item = getattr(subclass, 'item', None)
-            return callable(item)
-        return NotImplemented
-
-
-class HasMonitorCallback(Callback):
-    def __init__(self, monitor, larger_better, must_have_monitor=False):
-        self.set_monitor(monitor, larger_better)
-        self.must_have_monitor = must_have_monitor
-
-    def set_monitor(self, monitor, larger_better):
-        if callable(monitor):  # check that it accepts exactly one argument
-            _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
-            self.monitor = monitor
-        else:
-            self.monitor = str(monitor) if monitor is not None else None
-        self.larger_better = bool(larger_better)
-        if larger_better:
-            self.monitor_value = float('-inf')
-        else:
-            self.monitor_value = float('inf')
-        self._real_monitor = self.monitor
-
-    def on_after_trainer_initialized(self, trainer, driver):
-        """
-        If no monitor is set on this callback, take it from the monitor set on the Trainer.
-        For callbacks that must have a monitor, this function also performs that check.
-
-        :param trainer:
-        :param driver:
-        :return:
-        """
-        if self.monitor is None and trainer.monitor is not None:
-            self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
-        if self.must_have_monitor and self.monitor is None:
-            raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
-                               f"You can set it in the initialization or through Trainer.")
-
-    def get_monitor_value(self, results: Dict) -> Union[float, None]:
-        """
-        Get the monitor value. If the monitor key is not found directly, try fuzzy matching and store the
-        matched key in the self._real_monitor attribute.
-
-        :param results:
-        :return: None means that no suitable monitor was found this time.
-        """
-        if len(results) == 0:
-            return None
-        # make sure all tensors have been converted to plain Python types
-        results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
-        use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
-                                                        real_monitor=self._real_monitor,
-                                                        res=results)
-        if monitor_value is None:
-            return monitor_value
-        # first run
-        if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
-            logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
-                           f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
-        # the matched monitor differs from the one used last time
-        elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
-            logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
-                           f"The expected monitor is:`{self.monitor}`, last used monitor is:"
-                           f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
-                           f"customized monitor function when the evaluation results are varying between validation.")
-
-        self._real_monitor = use_monitor
-        return monitor_value
-
-    def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
-        """
-        Check whether monitor_value is better than the best value seen so far.
-
-        :param monitor_value: the value to check. If None, False is returned.
-        :param keep_if_better: if the given monitor_value is better, keep it as the new best.
-        :return:
-        """
-        if monitor_value is None:
-            return False
-        better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
-        if keep_if_better and better:
-            self.monitor_value = monitor_value
-        return better
-
-    def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
-        """
-        Of the two given values, whether monitor_value1 is the better one.
-
-        :param monitor_value1:
-        :param monitor_value2:
-        :return:
-        """
-        if monitor_value1 is None and monitor_value2 is None:
-            return True
-        if monitor_value1 is None:
-            return False
-        if monitor_value2 is None:
-            return True
-        better = False
-        if (self.larger_better and monitor_value1 > monitor_value2) or \
-                (not self.larger_better and monitor_value1 < monitor_value2):
-            better = True
-        return better
-
-    @property
-    def monitor_name(self):
-        """
-        Return the name of the monitor; if the monitor is a callable, return that function's name.
-
-        :return:
-        """
-        if callable(self.monitor):
-            try:
-                monitor_name = self.monitor.__qualname__
-            except:
-                monitor_name = self.monitor.__name__
-        elif self.monitor is None:
-            return None
-        else:
-            # this must be monitor rather than _real_monitor, because _real_monitor is re-initialized
-            # to monitor when the user runs again
-            monitor_name = str(self.monitor)
-        return monitor_name
@@ -10,9 +10,9 @@ from copy import deepcopy
 
 import fastNLP
-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from fastNLP.core.log import logger
-from fastNLP.envs import FASTNLP_LAUNCH_TIME
+from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK
 from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir

@@ -217,7 +217,8 @@ class CheckpointCallback(HasMonitorCallback):
         :return:
         """
         folder = self.timestamp_path.joinpath(folder_name)
-        synchronize_mkdir(folder)
+        if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:  # only create it on process 0
+            synchronize_mkdir(folder)
         _fn = getattr(trainer, self.save_fn_name)
         _fn(
             folder=folder,

@@ -4,7 +4,7 @@ __all__ = [
 
 from typing import Dict, Union, Callable
 
-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from fastNLP.core.utils.exceptions import EarlyStopException
fastNLP/core/callbacks/has_monitor_callback.py  (new file, 189 lines)
@@ -0,0 +1,189 @@
__all__ = [
    'HasMonitorCallback',
    'ExecuteOnceBetterMonitor'
]

from typing import Dict, Union, Any
from abc import ABC

from fastNLP.core.utils import apply_to_collection
from fastNLP.core.callbacks import Callback
from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.log import logger
from fastNLP.core.utils.utils import _check_valid_parameters_number


class CanItemDataType(ABC):
    """
    Detect objects that can be transferred (i.e. that expose an `item()` method).

    """

    @classmethod
    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
        if cls is CanItemDataType:
            item = getattr(subclass, 'item', None)
            return callable(item)
        return NotImplemented


class HasMonitorCallback(Callback):
    def __init__(self, monitor, larger_better, must_have_monitor=False):
        """
        This callback is not meant to be used directly; it serves as the base class for other monitor-related
        callbacks. A callback that uses a monitor can inherit from this class, which implements (1) validating
        the monitor and (2) setting its own monitor name from the Trainer's monitor when needed.

        :param monitor: the metric to monitor. If no exactly matching name is found in the evaluation results,
            a common-substring matching heuristic is used to pick the best-matching key as the monitor. If None,
            the monitor set on the Trainer is used. A function may also be passed in; it receives the evaluation
            results (a dict) and returns a float to be used as the monitor value.
        :param larger_better: whether a larger monitor value is better.
        :param must_have_monitor: whether this callback must have a monitor set. If True and no monitor is
            detected, an error is raised.
        """
        self.set_monitor(monitor, larger_better)
        self.must_have_monitor = must_have_monitor

    def set_monitor(self, monitor, larger_better):
        if callable(monitor):  # check that it accepts exactly one argument
            _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
            self.monitor = monitor
        else:
            self.monitor = str(monitor) if monitor is not None else None
        self.larger_better = bool(larger_better)
        if larger_better:
            self.monitor_value = float('-inf')
        else:
            self.monitor_value = float('inf')
        self._real_monitor = self.monitor

    def on_after_trainer_initialized(self, trainer, driver):
        """
        If no monitor is set on this callback, take it from the monitor set on the Trainer.
        For callbacks that must have a monitor, this function also performs that check.

        :param trainer:
        :param driver:
        :return:
        """
        if self.monitor is None and trainer.monitor is not None:
            self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
        if self.must_have_monitor and self.monitor is None:
            raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
                               f"You can set it in the initialization or through Trainer.")

    def get_monitor_value(self, results: Dict) -> Union[float, None]:
        """
        Get the monitor value. If the monitor key is not found directly, try fuzzy matching and store the
        matched key in the self._real_monitor attribute.

        :param results:
        :return: None means that no suitable monitor was found this time.
        """
        if len(results) == 0:
            return None
        # make sure all tensors have been converted to plain Python types
        results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
        use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
                                                        real_monitor=self._real_monitor,
                                                        res=results)
        if monitor_value is None:
            return monitor_value
        # first run
        if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
            logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
                           f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
        # the matched monitor differs from the one used last time
        elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
            logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
                           f"The expected monitor is:`{self.monitor}`, last used monitor is:"
                           f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
                           f"customized monitor function when the evaluation results are varying between validation.")

        self._real_monitor = use_monitor
        return monitor_value

    def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
        """
        Check whether monitor_value is better than the best value seen so far.

        :param monitor_value: the value to check. If None, False is returned.
        :param keep_if_better: if the given monitor_value is better, keep it as the new best.
        :return:
        """
        if monitor_value is None:
            return False
        better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
        if keep_if_better and better:
            self.monitor_value = monitor_value
        return better

    def is_better_results(self, results, keep_if_better=True):
        """
        Check whether the given results are better than the previous best; returns False if no matching
        monitor is found in these results.

        :param results: the evaluation results passed to the on_validate_end() interface.
        :param keep_if_better: when the check passes, whether to store the value in self.monitor_value.
        :return:
        """
        monitor_value = self.get_monitor_value(results)
        if monitor_value is None:
            return False
        return self.is_better_monitor_value(monitor_value, keep_if_better=keep_if_better)

    def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
        """
        Of the two given values, whether monitor_value1 is the better one.

        :param monitor_value1:
        :param monitor_value2:
        :return:
        """
        if monitor_value1 is None and monitor_value2 is None:
            return True
        if monitor_value1 is None:
            return False
        if monitor_value2 is None:
            return True
        better = False
        if (self.larger_better and monitor_value1 > monitor_value2) or \
                (not self.larger_better and monitor_value1 < monitor_value2):
            better = True
        return better

    @property
    def monitor_name(self):
        """
        Return the name of the monitor; if the monitor is a callable, return that function's name.

        :return:
        """
        if callable(self.monitor):
            try:
                monitor_name = self.monitor.__qualname__
            except:
                monitor_name = self.monitor.__name__
        elif self.monitor is None:
            return None
        else:
            # this must be monitor rather than _real_monitor, because _real_monitor is re-initialized
            # to monitor when the user runs again
            monitor_name = str(self.monitor)
        return monitor_name


class ExecuteOnceBetterMonitor(HasMonitorCallback):
    def __init__(self, monitor, larger_better, execute_fn):
        """
        Call execute_fn whenever the monitored result improves.

        :param monitor: the metric to monitor. If no exactly matching name is found in the evaluation results,
            a common-substring matching heuristic is used to pick the best-matching key as the monitor. If None,
            the monitor set on the Trainer is used. A function may also be passed in; it receives the evaluation
            results (a dict) and returns a float to be used as the monitor value.
        :param larger_better: whether a larger monitor value is better.
        :param execute_fn: a callable that takes no arguments and returns nothing; it is called whenever the
            monitor reaches a better result.
        """
        super().__init__(monitor, larger_better, must_have_monitor=True)
        _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn')
        self.execute_fn = execute_fn

    def on_validate_end(self, trainer, results):
        if self.is_better_results(results):
            self.execute_fn()
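To make the new base class concrete, here is a minimal sketch (not part of the diff) of a subclass that logs whenever the monitored metric improves; the Trainer wiring and the metric name are assumptions:

# Hypothetical subclass of HasMonitorCallback, reusing the machinery above.
class LogBestCallback(HasMonitorCallback):
    def __init__(self, monitor='acc', larger_better=True):
        super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)

    def on_validate_end(self, trainer, results):
        # is_better_results() resolves the monitor (with fuzzy matching when needed),
        # compares against the best value seen so far, and keeps the new best.
        if self.is_better_results(results, keep_if_better=True):
            print(f"new best {self.monitor_name}: {self.monitor_value}")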
@@ -4,7 +4,7 @@ __all__ = [
 
 import os
 from typing import Optional, Callable, Union
-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from io import BytesIO
 import shutil

@@ -80,10 +80,7 @@ class LoadBestModelCallback(HasMonitorCallback):
         self.get_monitor_value(sanity_check_res)
 
     def on_validate_end(self, trainer, results):
-        monitor_value = self.get_monitor_value(results)
-        if monitor_value is None:
-            return
-        if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+        if self.is_better_results(results, keep_if_better=True):
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
                                    model_save_fn=self.model_save_fn)
@@ -8,7 +8,7 @@ __all__ = [
     'RichCallback'
 ]
 
-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from fastNLP.core.callbacks.utils import _get_monitor_value
 from fastNLP.core.utils import f_rich_progress
 from fastNLP.core.log import logger

fastNLP/core/callbacks/torch_callbacks/__init__.py  (new file, 8 lines)
@@ -0,0 +1,8 @@
__all__ = [
    'TorchWarmupCallback',
    'TorchGradClipCallback'
]


from .torch_lr_sched_callback import TorchWarmupCallback
from .torch_grad_clip_callback import TorchGradClipCallback
fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py  (new file, 52 lines)
@@ -0,0 +1,52 @@
__all__ = [
    'TorchGradClipCallback'
]
from ..callback import Callback


class TorchGradClipCallback(Callback):
    def __init__(self, clip_value=1, clip_type='norm', parameters=None):
        r"""
        Clip the parameters' gradients before every optimizer update.

        :param float clip_value: clip gradients to [-clip_value, clip_value]; clip_value should be positive.
        :param str clip_type: one of the two supported modes::

            1 'norm', rescale the gradients so that their norm does not exceed clip_value.

            2 'value', clamp the gradients to [-clip_value, clip_value];
              gradients smaller than -clip_value are set to -clip_value,
              gradients larger than clip_value are set to clip_value.
        :param None,torch.Tensor,List[torch.Tensor] parameters: usually obtained through model.parameters().
            If None, the gradients of all parameters in the Trainer's optimizers are clipped.
        """
        super().__init__()

        from torch import nn
        if clip_type == 'norm':
            self.clip_fun = nn.utils.clip_grad_norm_
        elif clip_type == 'value':
            self.clip_fun = nn.utils.clip_grad_value_
        else:
            raise ValueError("Only supports `norm` or `value` right now.")
        if parameters is not None:
            self.parameters = list(parameters)
        else:
            self.parameters = None
        self.clip_value = clip_value

    def on_after_trainer_initialized(self, trainer, driver):
        assert 'torch' in driver.__class__.__name__.lower(), f"Callback:{self.__class__.__name__} only supports torch " \
                                                             f"related drivers for now."
        if self.parameters is None:  # fall back to every parameter found in the Trainer's optimizers
            parameters = []
            for optimizer in trainer.driver.optimizers:
                for param_group in optimizer.param_groups:
                    parameters.extend(param_group['params'])
            self.parameters = parameters
        assert len(self.parameters), "There is no parameter to clip."

    def on_before_optimizers_step(self, trainer, optimizers):
        for optimizer in trainer.driver.optimizers:
            trainer.driver.grad_scaler.unscale_(optimizer)
        self.clip_fun(self.parameters, self.clip_value)
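A usage sketch for the callback above; the constructor arguments come from the __init__ shown in this file, while the `callbacks` parameter on Trainer is an assumption:

# Hypothetical wiring: clip gradient norms to at most 1.0 before each optimizer update.
callback = TorchGradClipCallback(clip_value=1.0, clip_type='norm')
trainer = Trainer(model=model, driver='torch', callbacks=[callback])  # `callbacks=` is assumed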
fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py  (new file, 58 lines)
@@ -0,0 +1,58 @@
__all__ = [
    'TorchWarmupCallback'
]
import math

from ..callback import Callback


class TorchWarmupCallback(Callback):
    def __init__(self, warmup=0.1, schedule='constant'):
        r"""
        A callback that adjusts the learning rate; it only takes effect on steps where a parameter update
        actually happens.

        :param int,float warmup: if warmup is an int, the learning rate follows the warmup schedule until that
            step; if warmup is a float such as 0.1, the first 10% of the steps follow the warmup schedule.
        :param str schedule: how to adjust the learning rate.
            linear: rise to the configured learning rate (taken from the optimizers in the Trainer) over the
            warmup steps, then decay linearly to 0 over the remaining steps;
            constant: rise to the configured learning rate over the warmup steps, then keep it constant.
        """
        super().__init__()
        self.warmup = max(warmup, 0.)

        self.initial_lrs = []  # stores the learning rate of each param_group
        if schedule == 'constant':
            self.get_lr = self._get_constant_lr
        elif schedule == 'linear':
            self.get_lr = self._get_linear_lr
        else:
            raise RuntimeError("Only support 'linear', 'constant'.")

    def _get_constant_lr(self, progress):
        if progress < self.warmup:
            return progress / self.warmup
        return 1

    def _get_linear_lr(self, progress):
        if progress < self.warmup:
            return progress / self.warmup
        return max((progress - 1.) / (self.warmup - 1.), 0.)

    def on_train_begin(self, trainer):
        self.t_steps = trainer.total_batches
        if self.warmup > 1:
            self.warmup = self.warmup / self.t_steps
        self.t_steps = max(2, self.t_steps)  # must not be smaller than 2
        # round t_steps up so that it is divisible by accumulation_steps
        self.t_steps = math.ceil(self.t_steps / trainer.accumulation_steps) * trainer.accumulation_steps
        # record the initial learning rate of each param_group
        for optimizer in trainer.driver.optimizers:
            for group in optimizer.param_groups:
                self.initial_lrs.append(group['lr'])

    def on_before_optimizers_step(self, trainer, optimizers):
        # accumulation_steps is added here so that the lr does not start from 0
        progress = (trainer.global_forward_batches + trainer.accumulation_steps) / self.t_steps
        for optimizer in trainer.driver.optimizers:
            for lr, group in zip(self.initial_lrs, optimizer.param_groups):
                group['lr'] = lr * self.get_lr(progress)
|
||||
|
||||
""" 设置内部的 Evaluator """
|
||||
if metrics is None and evaluate_dataloaders is not None:
|
||||
raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.")
|
||||
raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.")
|
||||
|
||||
if metrics is not None and evaluate_dataloaders is None:
|
||||
raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.")
|
||||
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.")
|
||||
|
||||
self.evaluator = None
|
||||
self.monitor = monitor
|
||||
|
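The two checks mean `metrics` and `evaluate_dataloaders` must be provided together; a sketch of a valid configuration (the metric object, dataloaders, and monitor name are assumptions):

# Hypothetical configuration in which both renamed parameters are set together.
trainer = Trainer(
    model=model,
    driver='torch',
    train_dataloader=train_dl,       # assumed to exist
    evaluate_dataloaders=val_dl,     # must be paired with `metrics`
    metrics={'acc': Accuracy()},     # metric name and class assumed
    monitor='acc',
)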
@@ -129,7 +129,7 @@ class Driver(ABC):
     @property
     def optimizers(self) -> List:
         r"""
-        As shown below, the optimizers returned by the driver are always a List; if the user passes a single optimzer directly to the Trainer, we wrap it
+        As shown below, the optimizers returned by the driver are always a List; if the user passes a single optimizer directly to the Trainer, we wrap it
         in a List;
 
         :return: List[optimizer0, optimizer1, optimizer2, ...]
@@ -1,5 +1,5 @@
 import os
-from typing import Optional, Union
+from typing import Optional, Union, Callable, Dict, Tuple
 
 from .jittor_driver import JittorDriver
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

@@ -61,14 +61,11 @@ class JittorMPIDriver(JittorDriver):
             return self._data_device
         return self.model_device
 
-    def train_step(self, batch):
-        return self._train_step(batch)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        pass
 
-    def validate_step(self, batch):
-        return self._validate_step(batch)
-
-    def test_step(self, batch):
-        return self._test_step(batch)
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        pass
 
     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
@@ -1,9 +1,11 @@
-from typing import Dict, Union
+from typing import Dict, Union, Tuple, Callable, Optional
 
 from .jittor_driver import JittorDriver
 from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils.utils import _get_fun_msg
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
+from fastNLP.core.log import logger
 
 if _NEED_IMPORT_JITTOR:
     import jittor

@@ -27,42 +29,6 @@ class JittorSingleDriver(JittorDriver):
         self.global_rank = 0
         self.world_size = 1
 
-        if hasattr(self.model, "train_step"):
-            self._train_step = self.model.train_step
-            self._train_signature_fn = None
-        else:
-            self._train_step = self.model
-            model = self.unwrap_model()
-            self._train_signature_fn = model.execute
-
-        if hasattr(self.model, "evaluate_step"):
-            self._validate_step = self.model.evaluate_step
-            self._validate_signature_fn = None
-        elif hasattr(self.model, "test_step"):
-            self._validate_step = self.model.test_step
-            self._validate_signature_fn = self.model.test_step
-        else:
-            self._validate_step = self.model
-            model = self.unwrap_model()
-            self._validate_signature_fn = model.execute
-
-        if hasattr(self.model, "test_step"):
-            self._test_step = self.model.test_step
-            self._test_signature_fn = None
-        elif hasattr(self.model, "evaluate_step"):
-            self._test_step = self.model.evaluate_step
-            self._test_signature_fn = self.model.evaluate_step
-        else:
-            self._test_step = self.model
-            model = self.unwrap_model()
-            self._test_signature_fn = model.execute
-
-    def train_step(self, batch) -> Dict:
-        if isinstance(batch, Dict):
-            return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
-        else:
-            return self._train_step(batch)
-
     def step(self):
         """
         the step function of jittor optimizers can take the loss as an argument

@@ -80,18 +46,24 @@ class JittorSingleDriver(JittorDriver):
         for optimizer in self.optimizers:
             optimizer.zero_grad()
 
-    def validate_step(self, batch):
-        if isinstance(batch, Dict):
-            return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
+            return auto_param_call(fn, batch, signature_fn=signature_fn)
         else:
-            return self._validate_step(batch)
+            return fn(batch)
 
-    def test_step(self, batch):
-        if isinstance(batch, Dict):
-            return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        if hasattr(self.model, fn):
+            fn = getattr(self.model, fn)
+            if not callable(fn):
+                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
+            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
+            return fn, None
+        elif fn in {"train_step", "evaluate_step"}:
+            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
+            return self.model, self.model.forward
         else:
-            return self._test_step(batch)
+            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
 
     def unwrap_model(self):
         return self.model
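The refactor in this file (and in the paddle drivers below) replaces the per-mode train_step/validate_step/test_step dispatch with a single pair of hooks. Roughly, a caller is expected to combine them like this (a sketch; the exact call-site in Trainer/Evaluator is assumed, not shown in this diff):

# Resolve the step function once, then dispatch each batch through model_call.
fn, signature_fn = driver.get_model_call_fn("train_step")
for batch in dataloader:
    outputs = driver.model_call(batch, fn, signature_fn)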
fastNLP/core/drivers/paddle_driver/dist_utils.py  (new file, 376 lines)
@@ -0,0 +1,376 @@
import io
import pickle
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
from typing import Any, List

from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
    import torch
    from torch import distributed as dist
    if _TORCH_GREATER_EQUAL_1_8:
        try:
            from torch._C._distributed_c10d import ProcessGroupGloo
            from torch._C._distributed_c10d import _ProcessGroupWrapper
        except ImportError:
            pass


from fastNLP.core.utils import apply_to_collection


def _validate_output_list_for_rank(my_rank, dst, gather_list):
    if dst == my_rank:
        if not gather_list:
            raise ValueError(
                "Argument ``gather_list`` must be specified on destination rank."
            )
    elif gather_list:
        raise ValueError(
            "Argument ``gather_list`` must NOT be specified "
            "on non-destination ranks."
        )


def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
    """
    Gather objects from the other ranks onto the dst rank.

    Gathers picklable objects from the whole group in a single process.
    Similar to :func:`gather`, but Python objects can be passed in. Note that the
    object must be picklable in order to be gathered.

    Args:
        obj (Any): Input object. Must be picklable.
        object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
            should be correctly sized as the size of the group for this
            collective and will contain the output. Must be ``None`` on non-dst
            ranks. (default is ``None``)
        dst (int, optional): Destination rank. (default is 0)
        group: (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.

    Returns:
        None. On the ``dst`` rank, ``object_gather_list`` will contain the
        output of the collective.

    .. note:: Note that this API differs slightly from the gather collective
        since it does not provide an async_op handle and thus will be a blocking
        call.

    .. note:: Note that this API is not supported when using the NCCL backend.

    .. warning::
        :func:`gather_object` uses ``pickle`` module implicitly, which is
        known to be insecure. It is possible to construct malicious pickle data
        which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> dist.gather_object(
                gather_objects[dist.get_rank()],
                output if dist.get_rank() == 0 else None,
                dst=0
            )
        >>> # On rank 0
        >>> output
        ['foo', 12, {1: 2}]
    """
    if group is None:
        group = DEFAULT_TORCH_GROUP

    if dist.distributed_c10d._rank_not_in_group(group):
        return

    # Ensure object_gather_list is specified appropriately.
    my_rank = dist.get_rank()
    _validate_output_list_for_rank(my_rank, dst, object_gather_list)
    # move tensors to cpu so that unpickling does not end up on the sending gpu.
    obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
    input_tensor, local_size = _object_to_tensor(obj)
    group_backend = dist.get_backend(group)
    current_device = torch.device("cpu")
    is_nccl_backend = group_backend == dist.Backend.NCCL
    if is_nccl_backend:
        current_device = torch.device('cuda', torch.cuda.current_device())
        input_tensor = input_tensor.to(current_device)
        local_size = local_size.to(current_device)
    # Gather all local sizes. This is so that we can find the max size, and index
    # until the correct size when deserializing the tensors.
    group_size = dist.get_world_size(group=group)
    object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device)
    object_size_list = [
        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
    ]
    # Allgather tensor sizes. An all-gather is needed here despite this being a
    # gather, since each rank needs to broadcast a tensor of the same (maximal)
    # size.
    dist.all_gather(object_size_list, local_size, group=group)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Resize tensor to max size across all ranks.
    input_tensor.resize_(max_object_size)
    # Avoid populating output tensors if the result won't be gathered on this rank.
    if my_rank == dst:
        coalesced_output_tensor = torch.empty(
            max_object_size * group_size, dtype=torch.uint8, device=current_device
        )
        # Output tensors are nonoverlapping views of coalesced_output_tensor
        output_tensors = [
            coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
            for i in range(group_size)
        ]
    # All ranks call gather with equal-sized tensors.
    dist.gather(
        input_tensor,
        gather_list=output_tensors if my_rank == dst else None,
        dst=dst,
        group=group,
    )
    if my_rank != dst:
        return
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(torch.uint8)  # type: ignore[call-overload]
        tensor_size = object_size_list[i]
        object_gather_list[i] = _tensor_to_object(tensor, tensor_size)


def _object_to_tensor(obj, device=None):
    f = io.BytesIO()
    _pickler(f).dump(obj)
    byte_storage = torch.ByteStorage.from_buffer(f.getvalue())  # type: ignore[attr-defined]
    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
    # Otherwise, it will cause 100X slowdown.
    # See: https://github.com/pytorch/pytorch/issues/65696
    byte_tensor = torch.ByteTensor(byte_storage)
    local_size = torch.LongTensor([byte_tensor.numel()])
    if device is not None:
        byte_tensor = byte_tensor.to(device)
        local_size = local_size.to(device)
    return byte_tensor, local_size


def _tensor_to_object(tensor, tensor_size):
    buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size]
    return _unpickler(io.BytesIO(buf)).load()
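A quick sanity check of the serialization pair above; it needs only torch and no process group:

# Round-trip check for _object_to_tensor/_tensor_to_object (standalone sketch).
payload = {'a': [1, 2], 'b': 'hello'}
tensor, size = _object_to_tensor(payload)          # pickle -> ByteTensor plus its length
restored = _tensor_to_object(tensor, size.item())  # slice the bytes back and unpickle
assert restored == payload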
def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
    # src rank sends to all other ranks
    size = torch.LongTensor([0]).to(device)

    if cur_rank == src:
        world_size = dist.get_world_size(group=group)
        tensor, size = _object_to_tensor(obj)
        tensor = tensor.to(device)
        size = size.to(device)

        # first broadcast the size of obj;
        dist.broadcast(size, src, group=group)
        for subrank in range(world_size):
            if subrank != src:
                dist.send(tensor=tensor, dst=subrank, group=group, tag=tag)
    else:
        dist.broadcast(size, src, group=group)
        tensor = torch.ByteTensor([0] * size).to(device)
        dist.recv(tensor=tensor, src=src, group=group, tag=tag)

    return _tensor_to_object(tensor.cpu(), size)


def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) -> List:
    """
    An interface through which data of any type can be all_gathered. Non-tensor data is transferred by
    pickling it and unpickling it on the receiving side.

    example:
        obj = {
            'a': [1, 1],
            'b': [[1, 2], [1, 2]],
            'c': {
                'd': [1, 2]
            }
        }
        ->
        [
            {'a': 1, 'b':[1, 2], 'c':{'d': 1}},
            {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
        ]

    :param obj: data of any structure. If it is a tensor, its shape must be the same on every device. Any
        non-tensor object is serialized before the transfer.
    :param device: currently this parameter has no effect.
    :param group:
    :return: a list [obj0, obj1, ...] where obj_i is the obj from rank i.
    """
    if group is None:
        group = DEFAULT_TORCH_GROUP
    if isinstance(obj, torch.Tensor):
        objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
        dist.all_gather(objs, obj, group=group)
    else:
        objs = [None for _ in range(dist.get_world_size(group))]
        # move tensors to cpu so that unpickling does not end up on the sending gpu
        obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
        if _TORCH_GREATER_EQUAL_1_8:
            dist.all_gather_object(objs, obj, group=group)
        else:
            objs = all_gather_object(objs, obj, group=group)
    return objs


def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
    """
    Broadcast obj from src to the other ranks.

    :param obj:
    :param src:
    :param device:
    :param group:
    :return:
    """
    if group is None:
        group = DEFAULT_TORCH_GROUP
    cur_rank = dist.get_rank(group)
    if cur_rank == src:
        # move any tensors to cpu for pickling; otherwise unpickling may end up on the sending device
        obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
    if _TORCH_GREATER_EQUAL_1_8:
        if cur_rank != src:
            get_obj = [None]
            dist.broadcast_object_list(get_obj, src=src, group=group)
            return get_obj[0]
        else:
            dist.broadcast_object_list([obj], src=src, group=group)
            return obj
    if device is None:
        device = torch.cuda.current_device()

    if cur_rank == src:
        tensor, size = _object_to_tensor(obj, device=device)
    else:
        size = torch.LongTensor([0]).to(device)

    dist.broadcast(size, src=src, group=group)
    if cur_rank != src:
        tensor = torch.empty(
            size.int().item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device=device
        )
    dist.broadcast(tensor, src=src, group=group)

    return _tensor_to_object(tensor, tensor_size=size.item())


def _check_for_nccl_backend(group):
    pg = group or dist.distributed_c10d._get_default_group()
    # It is not expected for PG to be wrapped many times, but support it just
    # in case
    while isinstance(pg, _ProcessGroupWrapper):
        pg = pg.wrapped_pg

    return (
        dist.is_nccl_available() and
        isinstance(pg, dist.ProcessGroupNCCL)
    )


def all_gather_object(object_list, obj, group=None):
    """
    Copied from pytorch so that lower pytorch versions are also supported.

    Gathers picklable objects from the whole group into a list. Similar to
    :func:`all_gather`, but Python objects can be passed in. Note that the object
    must be picklable in order to be gathered.

    Args:
        object_list (list[Any]): Output list. It should be correctly sized as the
            size of the group for this collective and will contain the output.
        object (Any): Picklable Python object to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.

    Returns:
        None. If the calling rank is part of this group, the output of the
        collective will be populated into the input ``object_list``. If the
        calling rank is not part of the group, the passed in ``object_list`` will
        be unmodified.

    .. note:: Note that this API differs slightly from the :func:`all_gather`
        collective since it does not provide an ``async_op`` handle and thus
        will be a blocking call.

    .. note:: For NCCL-based processed groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.

    .. warning::
        :func:`all_gather_object` uses ``pickle`` module implicitly, which is
        known to be insecure. It is possible to construct malicious pickle data
        which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
        >>> output
        ['foo', 12, {1: 2}]
    """
    if dist.distributed_c10d._rank_not_in_group(group):
        return
    if _TORCH_GREATER_EQUAL_1_8:
        current_device = torch.device("cpu")
        is_nccl_backend = _check_for_nccl_backend(group)
        if is_nccl_backend:
            # See note about using torch.cuda.current_device() here in docstring.
            # We cannot simply use my_rank since rank == device is not necessarily
            # true.
            current_device = torch.device("cuda", torch.cuda.current_device())
    else:
        current_device = torch.cuda.current_device()

    input_tensor, local_size = _object_to_tensor(obj, device=current_device)

    # Gather all local sizes. This is so that we can find the max size, and index
    # until the correct size when deserializing the tensors.
    group_size = dist.get_world_size(group=group)
    object_sizes_tensor = torch.zeros(
        group_size, dtype=torch.long, device=current_device
    )
    object_size_list = [
        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
    ]
    # Allgather tensor sizes
    dist.all_gather(object_size_list, local_size, group=group)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Resize tensor to max size across all ranks.
    input_tensor.resize_(max_object_size)
    coalesced_output_tensor = torch.empty(
        max_object_size * group_size, dtype=torch.uint8, device=current_device
    )
    # Output tensors are nonoverlapping views of coalesced_output_tensor
    output_tensors = [
        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
        for i in range(group_size)
    ]
    dist.all_gather(output_tensors, input_tensor, group=group)
    # Deserialize outputs back to object.
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(torch.uint8)
        if tensor.device != torch.device("cpu"):
            tensor = tensor.cpu()
        tensor_size = object_size_list[i]
        object_list[i] = _tensor_to_object(tensor, tensor_size)
    return object_list
@@ -1,13 +1,12 @@
 import os
 import shutil
-from functools import partial
-from typing import List, Union, Optional, Dict
+from typing import List, Union, Optional, Dict, Tuple, Callable
 
 from .paddle_driver import PaddleDriver
 from .fleet_launcher import FleetLauncher
 from .utils import (
     _FleetWrappingModel,
     ForwardState,
     _MODE_PARAMETER,
     get_device_from_visible,
     reset_seed,
     replace_sampler,

@@ -47,8 +46,7 @@ if _NEED_IMPORT_PADDLE:
 
 __all__ = [
     "PaddleFleetDriver",
 ]
-# if os.path.exists(self.gloo_rendezvous_dir):
-#     shutil.rmtree(self.gloo_rendezvous_dir)
 
 class PaddleFleetDriver(PaddleDriver):
     def __init__(
         self,

@@ -104,34 +102,6 @@ class PaddleFleetDriver(PaddleDriver):
         # we simply set model_device to None;
         self._model_device = None
 
-        def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
-            if isinstance(batch, Dict) and not wo_auto_param_call:
-                return auto_param_call(step_fn, batch, signature_fn=signature_fn)
-            else:
-                return self._validate_step(batch)
-
-        model = model._layers
-        if hasattr(model, "train_step"):
-            logger.warning(
-                "Notice your model is a `paddle.DataParallel` model. And your "
-                "model also implements the `train_step` method, which we can not call actually, we will"
-                " call `forward` function instead of `train_step` and you should note that.")
-            self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
-
-        if hasattr(model, "evaluate_step"):
-            logger.warning(
-                "Notice your model is a `paddle.DataParallel` model. And your "
-                "model also implements the `evaluate_step` method, which we can not call actually, "
-                "we will call `forward` function instead of `evaluate_step` and you should note that.")
-            self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
-
-        if hasattr(model, "test_step"):
-            logger.warning(
-                "Notice your model is a `paddle.DataParallel` model. And your "
-                "model also implements the `test_step` method, which we can not call actually, we will"
-                " call `forward` function instead of `test_step` and you should note that.")
-            self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
-
         # when the parameter `device` is None and this parameter is not None, the data is moved to the specified device;
         self._data_device = kwargs.get("data_device", None)
         if self._data_device is not None:

@@ -150,8 +120,6 @@ class PaddleFleetDriver(PaddleDriver):
 
         self.world_size = None
         self.global_rank = 0
-        self._configured = False  # guards against repeated calls to configure_ddp()
-        self._has_setup = False  # guards against repeated calls to setup()
 
         self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {})
         check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)

@@ -173,6 +141,9 @@ class PaddleFleetDriver(PaddleDriver):
         os.makedirs(name=self.output_from_new_proc, exist_ok=True)
         self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)
 
+        self._has_setup = False  # set because the evaluator also performs a setup, which is unnecessary and should not happen;
+        self._has_fleetwrapped = False  # whether the passed-in model has been fleet-wrapped;
 
     def setup(self):
         """
         Launch the other subprocesses from the main process, with the main process as rank 0

@@ -268,17 +239,17 @@ class PaddleFleetDriver(PaddleDriver):
             dist.barrier()
 
     def configure_fleet(self):
-        if not self._configured and not isinstance(self.model, DataParallel):
+        if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
             self.model = DataParallel(
                 _FleetWrappingModel(self.model),
                 **self._fleet_kwargs
             )
+            self._has_fleetwrapped = True
 
-            self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
-            self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
-            self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)
-
-            self._configured = True
+    def on_exception(self):
+        if os.path.exists(self.gloo_rendezvous_dir):
+            shutil.rmtree(self.gloo_rendezvous_dir)
+        super().on_exception()
 
     @property
     def world_size(self) -> int:

@@ -310,14 +281,39 @@ class PaddleFleetDriver(PaddleDriver):
             return self._data_device
         return self.model_device
 
-    def train_step(self, batch):
-        return self._train_step(batch)
+    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
+        if self._has_fleetwrapped:
+            return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn,
+                              wo_auto_param_call=self.wo_auto_param_call)
+        else:
+            if isinstance(batch, Dict) and not self.wo_auto_param_call:
+                return auto_param_call(fn, batch, signature_fn=signature_fn)
+            else:
+                return fn(batch)
 
-    def validate_step(self, batch):
-        return self._validate_step(batch)
+    def get_model_call_fn(self, fn: str) -> Tuple:
+        model = self.unwrap_model()
+        if self._has_fleetwrapped:
+            if hasattr(model, fn):
+                fn = getattr(model, fn)
+                if not callable(fn):
+                    raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
+                return fn, None
+            elif fn in {"train_step", "evaluate_step"}:
+                return model, model.forward
+            else:
+                raise RuntimeError(f"There is no `{fn}` method in your model.")
+        else:
+            if hasattr(model, fn):
+                logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements "
+                               f"the `{fn}` method, which we can not call actually, we will"
+                               " call `forward` function instead of `train_step` and you should note that.")
+            elif fn not in {"train_step", "evaluate_step"}:
+                raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
+                                   "`DistributedDataParallel` model, which means that we will only call model.forward "
+                                   "function when we are in forward propagation.")
 
-    def test_step(self, batch):
-        return self._test_step(batch)
+        return self.model, model.forward
 
     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):

@@ -406,14 +402,6 @@ class PaddleFleetDriver(PaddleDriver):
         else:
             raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
 
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
-
     def is_global_zero(self):
         return self.global_rank == 0

@@ -450,3 +438,45 @@ class PaddleFleetDriver(PaddleDriver):
             if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)):
                 raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
                                  f"not {type(each_optimizer)}.")
 
+    def broadcast_object(self, obj, src: int = 0, group=None, **kwargs):
+        """
+        Send the object obj (either a tensor or an arbitrary object) from src to dst. Non-tensor objects are
+        packed with pickle for the transfer and loaded back on the dst side. Only meaningful in distributed
+        drivers.
+
+        :param obj: the object; may be a Tensor or nested data
+        :param int src: the global rank of the source.
+        :param int dst: the global rank(s) of the target; may be several ranks
+        :param group: the group this operation belongs to
+        :param kwargs:
+        :return: if the current driver is not distributed, the input obj is returned directly. If the current
+            rank is a receiver (its global rank is included in dst), the received object is returned; the
+            source rank returns what it sent; a rank that is neither sender nor receiver returns None.
+        """
+        return
+        return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)
+
+    def all_gather(self, obj, group) -> List:
+        """
+        Send obj to every other rank, where obj may be a Tensor or a nested-structure object. Data that is not
+        of a basic type is serialized via pickle and deserialized after it is received.
+
+        example:
+            obj = {
+                'a': [1, 1],
+                'b': [[1, 2], [1, 2]],
+                'c': {
+                    'd': [1, 2]
+                }
+            }
+            ->
+            [
+                {'a': 1, 'b':[1, 2], 'c':{'d': 1}},
+                {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+            ]
+
+        :param obj: the object to transfer; its structure should be the same on every rank.
+        :param group:
+        :return:
+        """
+        return
+        return fastnlp_paddle_all_gather(obj, group=group)
@@ -71,6 +71,14 @@ class PaddleDriver(Driver):
         for optimizer in self.optimizers:
             optimizer.clear_grad()
 
+    def backward(self, loss):
+        self.grad_scaler.scale(loss).backward()
+
+    def step(self):
+        for optimizer in self.optimizers:
+            self.grad_scaler.step(optimizer)
+            self.grad_scaler.update()
+
     @staticmethod
     def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         r"""

@@ -115,28 +123,6 @@ class PaddleDriver(Driver):
                 raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
                                  f"not {type(each_optimizer)}.")
 
-    def check_evaluator_mode(self, mode: str):
-        r"""
-        In a concrete driver's evaluate_step and test_step logic, if the model does not implement the function
-        in question we check whether it implements the other one; so if the user's evaluator evaluate_fn is
-        validate but the model passed in implements test_step rather than evaluate_step, we should remind the
-        user of this behavior;
-        """
-        model = self.unwrap_model()
-        if mode == "validate":
-            if not hasattr(model, "evaluate_step"):
-                if hasattr(model, "test_step"):
-                    logger.warning(
-                        "Your model does not have 'evaluate_step' method but has 'test_step' method, but you"
-                        "are using 'Evaluator.validate', we are going to use 'test_step' to substitute for"
-                        "'evaluate_step'.")
-
-        else:
-            if not hasattr(model, "test_step"):
-                if hasattr(model, "evaluate_step"):
-                    logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
-                                        "are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for"
-                                        "'test_step'.")
-
     @staticmethod
     def tensor_to_numeric(tensor, reduce=None):
         r"""

@@ -258,20 +244,21 @@ class PaddleDriver(Driver):
             if hasattr(sampler, "state_dict") and callable(sampler.state_dict):
                 sampler_states = sampler.state_dict()
                 # If present, num_consumed_samples needs special handling: because the DataLoader prefetches,
-                # using the sampler's num_consumed_samples directly would count more than was actually consumed.
-                num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None)
+                # using the sampler's num_consumed_samples directly would count more than was actually consumed.
+                num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
                 if num_consumed_samples_array is not None:
-                    sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
-                else:
-                    try:
-                        sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size
-                    except:  # batch_size may be None, in which case we merely lose some precision
-                        pass
-                assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report."
+                    if isinstance(sampler, ReproducibleSampler):
+                        # for a sampler we have to compute the actual number of consumed samples
+                        try:
+                            num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
+                        except:  # batch_size may be None, in which case we merely lose some precision
+                            num_consumed_batches = sampler_states['num_consumed_samples']
+                    sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
+                    assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+                states['sampler_states'] = sampler_states
             else:
                 raise RuntimeError(
                     "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
-            states["sampler_states"] = sampler_states
 
         # 2. save the model state;
         if should_save_model:
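Why the new branch indexes with the batch count: the sampler records, for every batch it has produced, how many samples had actually been consumed at that point, so indexing by the batches the training loop really finished ignores anything the DataLoader merely prefetched. A toy illustration (array contents assumed):

# Toy illustration of the num_consumed_samples_array bookkeeping (values assumed).
num_consumed_samples_array = [0, 8, 16, 24, 32]  # samples consumed after i batches
num_consumed_batches = 2                         # batches the loop actually finished
print(num_consumed_samples_array[num_consumed_batches])  # -> 16; prefetched batches are ignored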
@ -1,5 +1,5 @@
import os
from typing import Optional, Dict, Union
from typing import Optional, Dict, Union, Callable, Tuple

from .paddle_driver import PaddleDriver
from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible
@ -11,16 +11,19 @@ from fastNLP.core.utils import (
    get_paddle_device_id,
    paddle_move_data_to_device,
)
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import (
    ReproducibleBatchSampler,
    RandomBatchSampler,
    ReproducibleSampler,
    RandomSampler,
    re_instantiate_sampler,
)
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
    import paddle
    from paddle import DataParallel
    from paddle.fluid.reader import _DatasetKind

__all__ = [
@ -28,109 +31,57 @@ __all__ = [
]

class PaddleSingleDriver(PaddleDriver):
    def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs):
    def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs):
        if isinstance(model, DataParallel):
            raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`")

        cuda_visible_devices = os.environ.get(USER_CUDA_VISIBLE_DEVICES, None)
        if cuda_visible_devices == "":
            device = "cpu"
            logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to"
                        "use `cpu` instead of `gpu` device.")

        super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs)

        if device is None:
            raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.")

        if device != "cpu":
            if isinstance(device, int):
                device_id = device
            else:
                device_id = get_paddle_device_id(device)
            os.environ["CUDA_VISIBLE_DEVICES"] = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
        self.model_device = get_paddle_gpu_str(device)

        self.local_rank = 0
        self.global_rank = 0
        self.world_size = 1

        if isinstance(model, paddle.DataParallel):
            # note that `unwrap_model` here calls the concrete subclass's implementation;
            model = self.unwrap_model()
            if hasattr(model, "train_step"):
                logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
                               "implements the `train_step` method, which we can not call actually, we will "
                               " call `forward` function instead of `train_step` and you should note that.")
            self._train_step = self.model
            self._train_signature_fn = model.forward

            if hasattr(model, "evaluate_step"):
                logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
                               "implements the `evaluate_step` method, which we can not call actually, we "
                               "will call `forward` function instead of `evaluate_step` and you should note that.")
            self._validate_step = self.model
            self._validate_signature_fn = model.forward

            if hasattr(model, "test_step"):
                logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
                               "implements the `test_step` method, which we can not call actually, we will "
                               "call `forward` function instead of `test_step` and you should note that.")
            self._test_step = self.model
            self._test_signature_fn = model.forward
        else:
            if hasattr(self.model, "train_step"):
                self._train_step = self.model.train_step
                self._train_signature_fn = None
            else:
                self._train_step = self.model
                # the passed-in model is a `DataParallel`; we need to make sure its signature_fn is correct;
                model = self.unwrap_model()
                self._train_signature_fn = model.forward

            if hasattr(self.model, "evaluate_step"):
                self._validate_step = self.model.evaluate_step
                self._validate_signature_fn = None
            elif hasattr(self.model, "test_step"):
                self._validate_step = self.model.test_step
                self._validate_signature_fn = self.model.test_step
            else:
                self._validate_step = self.model
                model = self.unwrap_model()
                self._validate_signature_fn = model.forward

            if hasattr(self.model, "test_step"):
                self._test_step = self.model.test_step
                self._test_signature_fn = None
            elif hasattr(self.model, "evaluate_step"):
                self._test_step = self.model.evaluate_step
                self._test_signature_fn = self.model.evaluate_step
            else:
                self._test_step = self.model
                model = self.unwrap_model()
                self._test_signature_fn = model.forward

    def setup(self):
        device = self.model_device
        if device != "cpu":
            device_id = get_paddle_device_id(device)
            device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
            os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
            device = get_device_from_visible(device, output_type=str)
        device = get_device_from_visible(device, output_type=str)
        paddle.device.set_device(device)
        self.model.to(device)

    def train_step(self, batch) -> Dict:
        # if batch is a Dict, we do parameter matching for it by default; otherwise it is passed straight into `train_step` for the user to handle;
    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
        if isinstance(batch, Dict) and not self.wo_auto_param_call:
            return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
            return auto_param_call(fn, batch, signature_fn=signature_fn)
        else:
            return self._train_step(batch)
            return fn(batch)

    def backward(self, loss):
        self.grad_scaler.scale(loss).backward()

    def step(self):
        for optimizer in self.optimizers:
            self.grad_scaler.step(optimizer)
            self.grad_scaler.update()

    def validate_step(self, batch) -> Dict:
        if isinstance(batch, Dict) and not self.wo_auto_param_call:
            return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
    def get_model_call_fn(self, fn: str) -> Tuple:
        if hasattr(self.model, fn):
            fn = getattr(self.model, fn)
            if not callable(fn):
                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
            return fn, None
        elif fn in {"train_step", "evaluate_step"}:
            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
            return self.model, self.model.forward
        else:
            return self._validate_step(batch)

    def test_step(self, batch) -> Dict:
        if isinstance(batch, Dict) and not self.wo_auto_param_call:
            return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
        else:
            return self._test_step(batch)
            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

    def move_data_to_device(self, batch: 'paddle.Tensor'):
        r"""
@ -164,12 +115,18 @@ class PaddleSingleDriver(PaddleDriver):
            return replace_sampler(dataloader, sampler)

        if reproducible:
            batch_sampler = RandomBatchSampler(
                batch_sampler=args.batch_sampler,
                batch_size=args.batch_size,
                drop_last=args.drop_last
            )
            return replace_batch_sampler(dataloader, batch_sampler)
            if isinstance(args.sampler, paddle.io.RandomSampler):
                # it is already random, so replace it directly
                sampler = RandomSampler(args.sampler.data_source)
                logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
                return replace_sampler(dataloader, sampler)
            else:
                batch_sampler = RandomBatchSampler(
                    batch_sampler=args.batch_sampler,
                    batch_size=args.batch_size,
                    drop_last=args.drop_last
                )
                return replace_batch_sampler(dataloader, batch_sampler)
        else:
            return dataloader

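A standalone sketch of the dispatch rule the new `get_model_call_fn` implements — prefer a user-defined step method, otherwise fall back to `forward`. The toy model and the free-standing function are hypothetical stand-ins for the driver method:

    class ToyModel:
        def forward(self, x):
            return {'pred': x}

        def train_step(self, x):
            return {'loss': 0 * x}

    def get_model_call_fn(model, fn_name: str):
        if hasattr(model, fn_name):
            fn = getattr(model, fn_name)
            if not callable(fn):
                raise RuntimeError(f"The `{fn_name}` attribute is not `Callable`.")
            return fn, None  # no separate signature function needed
        elif fn_name in {"train_step", "evaluate_step"}:
            # Fall back to calling the model itself; parameters are matched
            # against `forward`'s signature.
            return model, model.forward
        raise RuntimeError(f"There is no `{fn_name}` method in your {type(model)}.")

    model = ToyModel()
    fn, signature_fn = get_model_call_fn(model, "evaluate_step")
    assert fn is model and signature_fn == model.forward
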
@ -11,7 +11,6 @@ from typing import Dict, Optional, Union

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to
from fastNLP.core.samplers import RandomSampler
from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.log import logger

@ -87,8 +86,6 @@ class ForwardState(IntEnum):
    TEST = 2
    PREDICT = 3

_MODE_PARAMETER = "forward_state"

class _FleetWrappingModel(Layer):
    """
    Modeled on _DDPWrappingModel: paddle's distributed training also needs the model wrapped in paddle.nn.DataParallel, using the same
@ -98,83 +95,16 @@ class _FleetWrappingModel(Layer):
        super(_FleetWrappingModel, self).__init__()
        self.model = model

        if isinstance(model, paddle.DataParallel):
            model = model._layers
            if hasattr(model, "train_step"):
                logger.warning(
                    "Notice your model is a `paddle.DataParallel` model. And your "
                    "model also implements the `train_step` method, which we can not call actually, we will"
                    " call `forward` function instead of `train_step` and you should note that.")
            self._train_step = self.model
            self._train_signature_fn = model.forward

            if hasattr(model, "evaluate_step"):
                logger.warning(
                    "Notice your model is a `paddle.DataParallel` model. And your "
                    "model also implements the `evaluate_step` method, which we can not call actually, "
                    "we will call `forward` function instead of `evaluate_step` and you should note that.")
            self._validate_step = self.model
            self._validate_signature_fn = model.forward

            if hasattr(model, "test_step"):
                logger.warning(
                    "Notice your model is a `paddle.DataParallel` model. And your "
                    "model also implements the `test_step` method, which we can not call actually, we will"
                    " call `forward` function instead of `test_step` and you should note that.")
            self._test_step = self.model
            self._test_signature_fn = model.forward
        else:
            if hasattr(model, "train_step"):
                self._train_step = model.train_step
                self._train_signature_fn = None
            else:
                self._train_step = model
                self._train_signature_fn = model.forward

            if hasattr(model, "evaluate_step"):
                self._validate_step = model.validate_step
                self._validate_signature_fn = None
            elif hasattr(model, "test_step"):
                self._validate_step = model.test_step
                self._validate_signature_fn = None
            else:
                self._validate_step = model
                self._validate_signature_fn = model.forward

            if hasattr(model, "test_step"):
                self._test_step = model.test_step
                self._test_signature_fn = None
            elif hasattr(model, "evaluate_step"):
                self._test_step = model.validate_step
                self._test_signature_fn = None
            else:
                self._test_step = model
                self._test_signature_fn = model.forward

    def forward(self, batch, **kwargs) -> Dict:

        forward_state = kwargs.pop(_MODE_PARAMETER)
        fn = kwargs.pop("fastnlp_fn")
        signature_fn = kwargs.pop("fastnlp_signature_fn")
        wo_auto_param_call = kwargs.pop("wo_auto_param_call")

        if forward_state == ForwardState.TRAIN:
            if isinstance(batch, Dict) and not wo_auto_param_call:
                return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
            else:
                return self._train_step(batch)
        elif forward_state == ForwardState.VALIDATE:
            if isinstance(batch, Dict) and not wo_auto_param_call:
                return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
            else:
                return self._validate_step(batch)
        elif forward_state == ForwardState.TEST:
            if isinstance(batch, Dict) and not wo_auto_param_call:
                return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
            else:
                return self._test_step(batch)
        elif forward_state == ForwardState.PREDICT:
            raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.")
        if isinstance(batch, Dict) and not wo_auto_param_call:
            return auto_param_call(fn, batch, signature_fn=signature_fn)
        else:
            raise NotImplementedError("You should direct a concrete evaluate_fn.")
            return fn(batch)

class DummyGradScaler:
    """

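A minimal sketch (not fastNLP code) of how the wrapper's `forward` now receives the already-resolved step function through reserved kwargs instead of branching on a ForwardState enum; keyword expansion stands in for `auto_param_call` here:

    def wrapped_forward(batch, **kwargs):
        fn = kwargs.pop("fastnlp_fn")
        signature_fn = kwargs.pop("fastnlp_signature_fn")
        wo_auto_param_call = kwargs.pop("wo_auto_param_call")
        if isinstance(batch, dict) and not wo_auto_param_call:
            # auto_param_call would match dict keys against fn's parameters;
            # plain ** expansion approximates that in this sketch.
            return fn(**batch)
        return fn(batch)

    print(wrapped_forward({'x': 3}, fastnlp_fn=lambda x: x + 1,
                          fastnlp_signature_fn=None, wo_auto_param_call=False))  # 4
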
@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
    # world_size and rank
    if FASTNLP_BACKEND_LAUNCH in os.environ:
        if device is not None:
            logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
            logger.warning_once("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
                        "up your script. And we will directly get the local device via "
                        "`os.environ['LOCAL_RANK']`.")
        return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)

@ -37,7 +37,12 @@ class TorchSingleDriver(TorchDriver):
        super(TorchSingleDriver, self).__init__(model, fp16=fp16, **kwargs)

        if device is None:
            raise ValueError("Parameter `device` can not be None in `TorchSingleDriver`.")
            logger.debug("device is not set, fastNLP will try to automatically get it.")
            try:
                device = next(model.parameters()).device
                assert isinstance(device, torch.device)
            except:
                raise ValueError("fastNLP cannot get device automatically, please set device explicitly.")

        self.model_device = device

@ -70,6 +75,7 @@ class TorchSingleDriver(TorchDriver):

            return self.model, model.forward
        else:
            # TODO calling one of the model's methods directly like this cannot trigger hooks; we may need a warning to remind users with hooks that train_step will not trigger them.
            if hasattr(self.model, fn):
                fn = getattr(self.model, fn)
                if not callable(fn):

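The new fallback when `device is None` can be demonstrated on its own; note that a model with no parameters would raise StopIteration, which the `except` above converts into the explicit ValueError:

    import torch

    model = torch.nn.Linear(4, 2)
    try:
        device = next(model.parameters()).device  # device of the first parameter
        assert isinstance(device, torch.device)
    except Exception:
        raise ValueError("fastNLP cannot get device automatically, "
                         "please set device explicitly.")
    print(device)  # cpu (or the GPU the model was moved to)
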
@ -25,7 +25,7 @@ __all__ = [

from .utils import optimizer_state_to_device
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env
from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler
from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
@ -224,6 +224,11 @@ class TorchDriver(Driver):
            optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
            optimizers_state_dict[f"optimizer{i}"] = optimizer_state  # note that deepcopy is not used here; testing showed it is unnecessary;

        # 4. save the fp16 state
        if not isinstance(self.grad_scaler, DummyGradScaler):
            grad_scaler_state_dict = self.grad_scaler.state_dict()
            states['grad_scaler_state_dict'] = grad_scaler_state_dict

        logger.debug("Save optimizer state dict")
        states["optimizers_state_dict"] = optimizers_state_dict
        torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
@ -232,7 +237,7 @@ class TorchDriver(Driver):
        states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))

        # 1. load the optimizers' state;
        optimizers_state_dict = states["optimizers_state_dict"]
        optimizers_state_dict = states.pop("optimizers_state_dict")
        for i in range(len(self.optimizers)):
            optimizer: torch.optim.Optimizer = self.optimizers[i]
            optimizer.load_state_dict(optimizers_state_dict[f"optimizer{i}"])
@ -244,26 +249,37 @@ class TorchDriver(Driver):
            res = torch.load(folder.joinpath(FASTNLP_MODEL_FILENAME), map_location='cpu')
            if only_state_dict:
                model.load_state_dict(res)
                logger.debug("Load model state dict.")
                logger.debug("Load model state dict...")
            else:
                model.load_state_dict(res.state_dict())
                logger.debug("Load model.")
                logger.debug("Load model...")

        # 3. restore the sampler state;
        # 3. load the fp16 state
        if 'grad_scaler_state_dict' in states:
            grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
            if not isinstance(self.grad_scaler, DummyGradScaler):
                self.grad_scaler.load_state_dict(grad_scaler_state_dict)
                logger.debug("Load grad_scaler state dict...")
        elif not isinstance(self.grad_scaler, DummyGradScaler):
            logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
                           f"the training process may be unstable.")

        # 4. restore the sampler state;
        dataloader_args = self.get_dataloader_args(dataloader)
        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
            sampler = dataloader_args.batch_sampler
        elif isinstance(dataloader_args.sampler, ReproducibleSampler):
            sampler = dataloader_args.sampler
        elif self.is_distributed():
            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
                               "`ReproducibleSampler`.")
        else:
            sampler = RandomBatchSampler(
                batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
                batch_size=dataloader_args.batch_size,
                drop_last=dataloader_args.drop_last
            )
        sampler.load_state_dict(states['sampler_states'])
        sampler.load_state_dict(states.pop('sampler_states'))
        states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

        # 4. update trainer_state.batch_idx_in_epoch

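A hedged sketch of the fp16 bookkeeping added above: persist the GradScaler state on save and restore it on load, warning when a fp16 run resumes from a non-fp16 checkpoint. The `print` stands in for the driver's logger:

    import torch

    # Save side: only persist state when real fp16 is in use.
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    states = {}
    if scaler.is_enabled():
        states['grad_scaler_state_dict'] = scaler.state_dict()

    # Load side: restore when present, otherwise warn if resuming into fp16.
    if 'grad_scaler_state_dict' in states:
        scaler.load_state_dict(states.pop('grad_scaler_state_dict'))
    elif scaler.is_enabled():
        print("checkpoint was not trained with fp16=True; training may be unstable")
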
@ -1,6 +1,7 @@
from typing import Optional, Dict, Union, Callable
from typing import Optional, Dict, Union, Callable, Tuple

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
from fastNLP.core.utils.utils import _get_fun_msg


if _NEED_IMPORT_PADDLE:
@ -48,33 +49,6 @@ class TorchPaddleDriver(Driver):
        elif self._data_device is not None:
            raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")

        if hasattr(self.model, "train_step"):
            self._train_step = self.model.train_step
            self._train_signature_fn = None
        else:
            self._train_step = self.model
            self._train_signature_fn = self.model.forward

        if hasattr(self.model, "evaluate_step"):
            self._validate_step = self.model.evaluate_step
            self._validate_signature_fn = None
        elif hasattr(self.model, "test_step"):
            self._validate_step = self.model.test_step
            self._validate_signature_fn = self.model.forward
        else:
            self._validate_step = self.model
            self._validate_signature_fn = self.model.forward

        if hasattr(self.model, "test_step"):
            self._test_step = self.model.test_step
            self._test_signature_fn = None
        elif hasattr(self.model, "evaluate_step"):
            self._test_step = self.model.evaluate_step
            self._test_signature_fn = self.model.forward
        else:
            self._test_step = self.model
            self._test_signature_fn = self.model.forward

    def setup(self):
        if self.model_device is not None:
            paddle.device.set_device(self.model_device.replace("cuda", "gpu"))
@ -103,12 +77,6 @@ class TorchPaddleDriver(Driver):
                                 f"'torch.optim.Optimizer' or 'paddle.optimizers.Optimizer' type, "
                                 f"not {type(each_optimizer)}.")

    def train_step(self, batch) -> Dict:
        if isinstance(batch, Dict):
            return auto_param_call(self._train_step, batch)
        else:
            return self._train_step(batch)

    def step(self):
        for optimizer in self.optimizers:
            optimizer.step()
@ -125,17 +93,24 @@ class TorchPaddleDriver(Driver):
        else:
            raise ValueError("Unknown optimizers type.")

    def validate_step(self, batch):
        if isinstance(batch, Dict):
            return auto_param_call(self._validate_step, batch)
    def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
        if isinstance(batch, Dict) and not self.wo_auto_param_call:
            return auto_param_call(fn, batch, signature_fn=signature_fn)
        else:
            return self._validate_step(batch)
            return fn(batch)

    def test_step(self, batch):
        if isinstance(batch, Dict):
            return auto_param_call(self._test_step, batch)
    def get_model_call_fn(self, fn: str) -> Tuple:
        if hasattr(self.model, fn):
            fn = getattr(self.model, fn)
            if not callable(fn):
                raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
            return fn, None
        elif fn in {"train_step", "evaluate_step"}:
            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
            return self.model, self.model.forward
        else:
            return self._test_step(batch)
            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

    def predict_step(self, batch):
        if isinstance(batch, Dict):

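The `model_call` pattern above leans on `auto_param_call`, which picks from a dict batch only the arguments the function declares. A simplified, hypothetical stand-in for that helper:

    import inspect

    def auto_param_call(fn, batch, signature_fn=None):
        # Inspect the (possibly separate) signature function and keep only
        # the batch entries whose keys match its parameter names.
        sig = inspect.signature(signature_fn or fn)
        kwargs = {k: v for k, v in batch.items() if k in sig.parameters}
        return fn(**kwargs)

    def loss_fn(x, y):
        return {'loss': abs(x - y)}

    print(auto_param_call(loss_fn, {'x': 1, 'y': 3, 'extra': 0}))  # {'loss': 2}
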
@ -1,9 +1,4 @@
__all__ = [
    'BucketSampler',
    'SortedSampler',
    'ConstTokenNumSampler',
    'ConstantTokenNumSampler',

    'MixSampler',
    'DopedSampler',
    'MixSequentialSampler',
@ -26,7 +21,6 @@ __all__ = [
    "re_instantiate_sampler"
]

from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler

@ -1,728 +0,0 @@
r"""
The sampler submodule implements the various samplers fastNLP needs.
"""

__all__ = [
    "BucketSampler",
    "SortedSampler",
    'ConstTokenNumSampler',
    "ConstantTokenNumSampler",
]

from itertools import chain
from typing import List, Iterable

import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    from torch.utils.data import Sampler
else:
    from fastNLP.core.utils.dummy_class import DummyClass as Sampler

# class DopedSampler(Sampler):
#     """
#     A BatchSampler tailored to MixDataLoader: it mixes samples from the passed-in list of datasets into batches.
#     """
#
#     def __init__(self, dataset: Union[List, Dict], batch_size: int = None,
#                  sampler: Union[List[Sampler], Dict[str, Sampler]] = None,
#                  ds_ratio: Union[str, None, List[float], Dict[str, float]] = None, drop_last: bool = False) -> None:
#         if batch_size <= 0:
#             raise ValueError("batch_size should be a positive integer value, "
#                              "but got batch_size={}".format(batch_size))
#         if not isinstance(drop_last, bool):
#             raise ValueError("drop_last should be a boolean value, but got "
#                              "drop_last={}".format(drop_last))
#         self.batch_size = batch_size
#         self.drop_last = drop_last
#         self.ds_ratio = ds_ratio
#         if sampler is None:
#             if isinstance(dataset, List):
#                 self.sampler = [SequentialSampler(ds) for ds in dataset]
#             elif isinstance(dataset, Dict):
#                 self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
#
#         elif isinstance(sampler, List):
#             if len(sampler) != len(dataset):
#                 raise ValueError("the length of sampler != the length of sampler")
#             self.sampler = sampler
#         else:
#             self.sampler = sampler
#         if ds_ratio == 'pad_to_most' or ds_ratio == 'truncate_to_least' or ds_ratio is None:
#             self.ds_ratio = ds_ratio
#         elif isinstance(ds_ratio, List):
#             if not all(item >= 0 for item in ds_ratio):
#                 raise ValueError("batch_size should be a positive integer value, "
#                                  "but got batch_size={}".format(ds_ratio))
#             self.ds_ratio = ds_ratio
#         else:
#             raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None")
#
#     def __iter__(self):
#         samplers, index = [], 0
#         if isinstance(self.sampler, List):
#             for idx, sampler in enumerate(self.sampler):
#                 samplers.append((iter(sampler), self.batch_size, index, 0, idx))
#                 index += len(sampler)
#         elif isinstance(self.sampler, Dict):
#             for name, sampler in self.sampler.items():
#                 samplers.append((iter(sampler), self.batch_size, index, 0, name))
#                 index += len(sampler)
#
#     def __len__(self):
#         lens = 0
#         max_len, ds_len = 0, 0
#         if self.ds_ratio == 'truncate_to_least':
#             if isinstance(self.sampler, List):
#                 max_len = min(len(sampler) for sampler in self.sampler)
#                 ds_len = len(self.sampler)
#             elif isinstance(self.sampler, Dict):
#                 max_len = min(len(sampler) for _, sampler in self.sampler.items())
#                 for _, _ in self.sampler.items():
#                     ds_len += 1
#
#         elif self.ds_ratio == 'pad_to_most':
#             if isinstance(self.sampler, List):
#                 max_len = max(len(sampler) for sampler in self.sampler)
#                 ds_len = len(self.sampler)
#             elif isinstance(self.sampler, Dict):
#                 max_len = max(len(sampler) for _, sampler in self.sampler.items())
#                 for _, _ in self.sampler.items():
#                     ds_len += 1
#
#         if self.ds_ratio is None:
#             if isinstance(self.sampler, List):
#                 for i in range(len(self.sampler)):
#                     sampler = self.sampler[i]
#                     if self.drop_last:
#                         lens += len(sampler) // self.batch_size
#                     else:
#                         lens += (len(sampler) + self.batch_size - 1) // self.batch_size
#             elif isinstance(self.sampler, Dict):
#                 for name, sampler in self.sampler.items():
#                     if self.drop_last:
#                         lens += len(sampler) // self.batch_size
#                     else:
#                         lens += (len(sampler) + self.batch_size - 1) // self.batch_size
#         elif self.ds_ratio == 'truncate_to_least' or self.ds_ratio == 'pad_to_most':
#             for i in range(ds_len):
#                 if self.drop_last:
#                     lens += max_len // self.batch_size
#                 else:
#                     lens += (max_len + self.batch_size - 1) // self.batch_size
#         return lens
#
#     def demo(self):
#         indexes = np.array([0]*self.batch_size + [1]*self.batch_size + [2]*self.batch_size)
#         shift = np.array([0]*self.batch_size + [len(ds1)]*self.batch_size + [len(ds1)+len(ds2)]*self.batch_size)
#         buffer = np.zeros(self.batch_size*self.num_ds, dtype=int)
#         select_sampler = np.random.randint(0, self.batch_size*self.num_ds, num_sample=self.batch_size)
#         select_indices = buffer[select_sampler] + shift[select_sampler]
#         num_1 = (indexes[select_sampler]==0).sum()
#


# class MixSequentialSampler(Sampler):
#     """
#     A BatchSampler tailored to MixDataLoader: it samples the passed-in list of datasets in order and returns indexes; a dataset is only processed after the previous one is finished.
#     """
#
#     def __init__(self, dataset: Union[List, Dict], batch_size: int = None,
#                  sampler: Union[List[Sampler], Dict[str, Sampler], None] = None,
#                  drop_last: bool = False) -> None:
#         """
#
#         :param dataset: a list of data containers implementing __getitem__ and __len__
#         :param batch_size: the batch size for each dataset; may be a list or an int, an int applies to all datasets
#         :param sampler: instantiated samplers, one per dataset
#         :param drop_last: whether to drop the last batch whose length is smaller than batch_size
#         """
#         # if dataset is a Dict, other parameters such as collate_fn must be a Dict or a Callable,
#         if isinstance(dataset, Dict) and isinstance(sampler, List):
#             raise ValueError(f"{sampler} must be dict")
#
#         # check that batch_size is positive
#         if batch_size <= 0:
#             raise ValueError("batch_size should be a positive integer value, "
#                              "but got batch_size={}".format(batch_size))
#
#         if not isinstance(drop_last, bool):
#             raise ValueError("drop_last should be a boolean value, but got "
#                              "drop_last={}".format(drop_last))
#         self.batch_size = batch_size
#         self.drop_last = drop_last
#         if sampler is None:
#             if isinstance(dataset, List):
#                 self.sampler = [SequentialSampler(ds) for ds in dataset]
#             elif isinstance(dataset, Dict):
#                 self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
#         elif isinstance(sampler, List):
#             if len(sampler) != len(dataset):
#                 raise ValueError("the length of sampler != the length of sampler")
#             self.sampler = sampler
#
#     def __iter__(self) -> Iterable[List[int]]:
#         """
#         Sample in dataset order and yield once a batch is assembled.
#         :return:
#         """
#         index = 0
#         batch = []
#         if isinstance(self.sampler, List):
#             for i in range(len(self.sampler)):
#                 sampler = self.sampler[i]
#                 for idx in sampler:
#                     batch.append(idx + index)
#                     if len(batch) == self.batch_size:
#                         yield batch
#                         batch = []
#                 if len(batch) > 0 and not self.drop_last:
#                     yield batch
#                     batch = []
#                 index += len(sampler)
#         elif isinstance(self.sampler, Dict):
#             for name, sampler in self.sampler.items():
#                 for idx in sampler:
#                     batch.append(idx + index)
#                     if len(batch) == self.batch_size:
#                         yield batch
#                         batch = []
#                 if len(batch) > 0 and not self.drop_last:
#                     yield batch
#                     batch = []
#                 index += len(sampler)
#
#     def __len__(self) -> int:
#         lens = 0
#         if isinstance(self.sampler, List):
#             for i in range(len(self.sampler)):
#                 sampler = self.sampler[i]
#                 if self.drop_last:
#                     lens += len(sampler) // self.batch_size
#                 else:
#                     lens += (len(sampler) + self.batch_size - 1) // self.batch_size
#         elif isinstance(self.sampler, Dict):
#             for _, sampler in self.sampler.items():
#                 if self.drop_last:
#                     lens += len(sampler) // self.batch_size
#                 else:
#                     lens += (len(sampler) + self.batch_size - 1) // self.batch_size
#         return lens


# class PollingSampler(Sampler):
#     """
#     A BatchSampler tailored to MixDataLoader: it samples the passed-in list of datasets in round-robin fashion and returns indexes, moving on to the next dataset after taking one batch from the previous one.
#     """
#
#     def __init__(self, dataset: Union[List, Dict], batch_size: int = 16,
#                  sampler: Union[List[Sampler], Dict[str, Sampler]] = None,
#                  drop_last: bool = False, ds_ratio="pad_to_most") -> None:
#         """
#
#         :param dataset: a list of data containers implementing __getitem__ and __len__
#         :param batch_size: the batch size for each dataset; may be a list or an int, an int applies to all datasets
#         :param sampler: instantiated samplers, one per dataset
#         :param drop_last: whether to drop the last batch whose length is smaller than batch_size
#         :param ds_ratio: when ds_ratio=None, sample the dataset list in turn until every dataset is exhausted; when ds_ratio='truncate_to_least',
#             the shortest dataset in the list is the reference and longer datasets are truncated; when ds_ratio='pad_to_most', the longest dataset is the reference and shorter datasets are re-sampled
#         """
#         # if dataset is a Dict, other parameters such as collate_fn must be a Dict or a Callable,
#         if isinstance(dataset, Dict) and isinstance(sampler, List):
#             raise ValueError(f"{sampler} must be dict")
#         if isinstance(dataset, List) and isinstance(sampler, Dict):
#             raise ValueError(f"{sampler} must be list")
#         # check that batch_size is positive
#         if batch_size <= 0:
#             raise ValueError("batch_size should be a positive integer value, "
#                              "but got batch_size={}".format(batch_size))
#
#         if not isinstance(drop_last, bool):
#             raise ValueError("drop_last should be a boolean value, but got "
#                              "drop_last={}".format(drop_last))
#
#         self.batch_size = batch_size
#         self.drop_last = drop_last
#         if sampler is None:
#             if isinstance(dataset, List):
#                 self.sampler = [SequentialSampler(ds) for ds in dataset]
#             elif isinstance(dataset, Dict):
#                 self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
#
#         elif isinstance(sampler, List):
#             if len(sampler) != len(dataset):
#                 raise ValueError("the length of sampler != the length of sampler")
#             self.sampler = sampler
#         else:
#             self.sampler = sampler
#         if ds_ratio == 'pad_to_most' or ds_ratio == 'truncate_to_least' or ds_ratio is None:
#             self.ds_ratio = ds_ratio
#         else:
#             raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None")
#
#     def __iter__(self) -> Iterable[List[int]]:
#         # index is the base offset for each dataset's indices; pointer points at one dataset in the list
#         index, pointer, samplers, flag = 0, 0, [], False
#
#         if isinstance(self.sampler, List):
#             for idx, sampler in enumerate(self.sampler):
#                 samplers.append((iter(sampler), self.batch_size, index, 0, idx))
#                 index += len(sampler)
#         elif isinstance(self.sampler, Dict):
#             for name, sampler in self.sampler.items():
#                 samplers.append((iter(sampler), self.batch_size, index, 0, name))
#                 index += len(sampler)
#         if self.ds_ratio == 'pad_to_most':
#             if isinstance(self.sampler, List):
#                 limit_len = max(len(ds) for ds in self.sampler)
#             else:
#                 limit_len = max(len(ds) for _, ds in self.sampler.items())
#         elif self.ds_ratio == 'truncate_to_least':
#             if isinstance(self.sampler, List):
#                 limit_len = min(len(ds) for ds in self.sampler)
#             else:
#                 limit_len = min(len(ds) for _, ds in self.sampler.items())
#         else:
#             limit_len = 0
#         # size of the last batch
#         last_batch_size = limit_len % self.batch_size
#
#         while True:
#             # everything has been sampled; exit
#             if len(samplers) == 0:
#                 break
#             batch, flag = [], False
#             # sampler_len is the number of samples already drawn
#             sampler, batch_size, index, sampler_len, name = samplers.pop(0)
#             for _ in range(batch_size):
#                 try:
#                     batch.append(index + next(sampler))
#                     sampler_len += 1
#                 except StopIteration:
#                     flag = True
#                     # when ds_ratio is None (the first case), simply drop the exhausted dataset.
#                     if self.ds_ratio == 'pad_to_most' and sampler_len < limit_len:
#                         # reset the sampler and draw enough to fill the batch
#                         sampler = iter(self.sampler[name])
#                         # since batch_size is never larger than the dataset length, a full batch can always be drawn
#                         for _ in range(batch_size-len(batch)):
#                             batch.append(next(sampler) + index)
#                             sampler_len += 1
#                     break
#
#             # the cases where ds_ratio is not None
#             # two situations trigger the logic below: 1. with truncate_to_least, when the shortest dataset's last batch is smaller than batch_size,
#             # the last batch of the longer datasets would come out longer; 2. with pad_to_most, when the longest dataset's last batch is smaller than batch_size, the last batch of shorter
#             # datasets would come out longer
#             if limit_len != 0 and limit_len < sampler_len:
#                 batch = batch[:last_batch_size]
#             # for any ds_ratio, if the dataset still has data left, append it back to the end of the queue
#             elif (limit_len == 0 and flag == False) or limit_len > sampler_len:
#                 samplers.append((sampler, batch_size, index, sampler_len, name))
#             if len(batch) == batch_size:
#                 yield batch
#             elif len(batch) > 0 and not self.drop_last:
#                 yield batch
#
#     def __len__(self) -> int:
#         lens = 0
#         max_len, ds_len = 0, 0
#         if self.ds_ratio == 'truncate_to_least':
#             if isinstance(self.sampler, List):
#                 max_len = min(len(sampler) for sampler in self.sampler)
#                 ds_len = len(self.sampler)
#             elif isinstance(self.sampler, Dict):
#                 max_len = min(len(sampler) for _, sampler in self.sampler.items())
#                 for _, _ in self.sampler.items():
#                     ds_len += 1
#
#         elif self.ds_ratio == 'pad_to_most':
#             if isinstance(self.sampler, List):
#                 max_len = max(len(sampler) for sampler in self.sampler)
#                 ds_len = len(self.sampler)
#             elif isinstance(self.sampler, Dict):
#                 max_len = max(len(sampler) for _, sampler in self.sampler.items())
#                 for _, _ in self.sampler.items():
#                     ds_len += 1
#         if self.ds_ratio is None:
#             if isinstance(self.sampler, List):
#                 for i in range(len(self.sampler)):
#                     sampler = self.sampler[i]
#                     if self.drop_last:
#                         lens += len(sampler) // self.batch_size
#                     else:
#                         lens += (len(sampler) + self.batch_size - 1) // self.batch_size
#             elif isinstance(self.sampler, Dict):
#                 for name, sampler in self.sampler.items():
#                     if self.drop_last:
#                         lens += len(sampler) // self.batch_size
#                     else:
#                         lens += (len(sampler) + self.batch_size - 1) // self.batch_size
#         else:
#             for i in range(ds_len):
#                 if self.drop_last:
#                     lens += max_len // self.batch_size
#                 else:
#                     lens += (max_len + self.batch_size - 1) // self.batch_size
#         return lens


class BucketSampler(Sampler):
    r"""
    A `Random Sampler` with buckets: it can randomly draw elements of similar length.
    """

    def __init__(self, dataset, num_buckets=10, batch_size=None, seq_len_field_name='seq_len', drop_last=False) -> None:
        r"""

        :param int num_buckets: the number of buckets
        :param int batch_size: the batch size. Defaults to None; Trainer/Tester will set it correctly when they invoke
            BucketSampler. When used outside a Trainer/Tester, the value must be passed explicitly.
        :param str seq_len_field_name: the name of the `field` holding the sequence length
        """
        self.dataset = dataset
        self.num_buckets = num_buckets
        self.batch_size = batch_size
        self.seq_len_field_name = seq_len_field_name

    def set_batch_size(self, batch_size) -> None:
        r"""

        :param int batch_size: the size of each batch
        :return:
        """
        self.batch_size = batch_size

    def __iter__(self):
        if self.batch_size is None:
            raise RuntimeError("batch_size is None.")
        seq_lens = self.dataset.get_all_fields()[self.seq_len_field_name].content
        total_sample_num = len(seq_lens)

        bucket_indexes = []
        assert total_sample_num >= self.num_buckets, "The number of samples is smaller than the number of buckets."
        num_sample_per_bucket = total_sample_num // self.num_buckets
        for i in range(self.num_buckets):
            bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])
        bucket_indexes[-1][1] = total_sample_num

        sorted_seq_lens = list(sorted([(idx, seq_len) for
                                       idx, seq_len in zip(range(total_sample_num), seq_lens)],
                                      key=lambda x: x[1]))

        batchs = []

        left_init_indexes = []
        for b_idx in range(self.num_buckets):
            start_idx = bucket_indexes[b_idx][0]
            end_idx = bucket_indexes[b_idx][1]
            sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
            left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
            num_batch_per_bucket = len(left_init_indexes) // self.batch_size
            np.random.shuffle(left_init_indexes)
            for i in range(num_batch_per_bucket):
                batchs.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size])
            left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:]
        if len(left_init_indexes) != 0:
            batchs.append(left_init_indexes)
        np.random.shuffle(batchs)

        return chain(*batchs)


class ConstTokenNumSampler(Sampler):
    """
    Try to keep the number of input tokens in each batch roughly constant.

    """

    def __init__(self, dataset, seq_len_field_name: List[int], max_token: int = 4096, max_sentence: int = -1,
                 need_be_multiple_of: int = 1, num_bucket: int = -1) -> None:
        """

        :param dataset:
        :param List[int] seq_len_field_name: which field indicates the length of a sample
        :param int max_token: the maximum number of tokens per batch
        :param int max_sentence: the maximum number of instances per batch; -1 means it is decided by max_token
        :param int need_be_multiple_of: the number of instances per generated batch must be a multiple of this value; used in DataParallel scenarios
        :param int num_bucket: split the data by length into num_bucket buckets; samples in a batch are combined within a bucket as much as possible, which reduces padding.
        """
        assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1
        self.dataset = dataset
        self.seq_len_field_name = seq_len_field_name
        self.num_bucket = num_bucket
        self.max_token = max_token
        self._max_sentence = max_sentence
        self.need_be_multiple_of = need_be_multiple_of

        assert len(self.dataset) > self.num_bucket, "The number of samples should be larger than buckets."
        seq_len = self.dataset.get_field(self.seq_len_field_name)
        self.seq_len = seq_len
        seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
        seq_len_indice.sort(key=lambda x: x[0])
        indice_in_buckets = []
        if self.num_bucket > 0:
            sample_per_bucket = len(seq_len_indice) // self.num_bucket
            i = 0
            while len(indice_in_buckets) < len(seq_len_indice):
                indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket])
                i += 1
        else:
            indice_in_buckets = [seq_len_indice]
        self.indice_in_buckets = indice_in_buckets
        self.get_new_order()

    @property
    def max_sentence(self):
        if self._max_sentence < 1:
            return 100000000
        return self._max_sentence

    @max_sentence.setter
    def max_sentence(self, max_sentence):
        self._max_sentence = max_sentence

    def get_new_order(self) -> None:
        np.random.shuffle(self.indice_in_buckets)
        for bucket in self.indice_in_buckets:
            np.random.shuffle(bucket)
        indices = list(chain(*self.indice_in_buckets))
        batches = []
        cur_max_len = 0
        batch = []
        for length, i in indices:
            max_len = max(length, cur_max_len)
            if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence:
                left_sample = len(batch) % self.need_be_multiple_of
                add_samples = batch.copy()
                cur_max_len = length
                if left_sample != 0:
                    add_samples = add_samples[:-left_sample]
                    batch = batch[-left_sample:]
                    cur_max_len = max(cur_max_len, max(batch))
                else:
                    batch = []
                if len(add_samples) == 0:
                    raise RuntimeError(
                        f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
                batches.append(add_samples)
            else:
                cur_max_len = max_len
            batch.append(i)
        if batch:
            left_sample = len(batch) % self.need_be_multiple_of
            add_samples = batch.copy()
            if left_sample != 0:
                add_samples = add_samples[:-left_sample].copy()
            if add_samples:
                batches.append(add_samples)
        np.random.shuffle(batches)
        self.batches = batches

    def __iter__(self) -> Iterable[int]:
        for batch in self.batches:
            yield batch
        self.get_new_order()

    def __len__(self):
        return len(self.batches)


class ConstantTokenNumSampler:
    """
    Try to keep the number of input tokens in each batch roughly constant.

    """

    def __init__(self, seq_len, max_token: List[int] = 4096, max_sentence: int = -1,
                 need_be_multiple_of: int = 1, num_bucket: int = -1) -> None:
        """

        :param List[int] seq_len: list[int], the length of each sample. Usually obtained via dataset.get_field('seq_len').content
        :param int max_token: the maximum number of tokens per batch
        :param int max_sentence: the maximum number of instances per batch; -1 means it is decided by max_token
        :param int need_be_multiple_of: the number of instances per generated batch must be a multiple of this value; used in DataParallel scenarios
        :param int num_bucket: split the data by length into num_bucket buckets; samples in a batch are combined within a bucket as much as possible, which reduces padding.
        """
        assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1
        assert len(seq_len) > num_bucket, "The number of samples should be larger than buckets."
        self.seq_len = seq_len
        self.max_token = max_token
        self._max_sentence = max_sentence
        self.need_be_multiple_of = need_be_multiple_of
        seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
        seq_len_indice.sort(key=lambda x: x[0])
        indice_in_buckets = []
        if num_bucket > 0:
            sample_per_bucket = len(seq_len_indice) // num_bucket
            i = 0
            while len(indice_in_buckets) < len(seq_len_indice):
                indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket])
                i += 1
        else:
            indice_in_buckets = [seq_len_indice]
        self.indice_in_buckets = indice_in_buckets
        self.get_new_order()

    @property
    def max_sentence(self):
        if self._max_sentence < 1:
            return 100000000
        return self._max_sentence

    @max_sentence.setter
    def max_sentence(self, max_sentence):
        self._max_sentence = max_sentence

    def get_new_order(self) -> None:
        np.random.shuffle(self.indice_in_buckets)
        for bucket in self.indice_in_buckets:
            np.random.shuffle(bucket)
        indices = list(chain(*self.indice_in_buckets))
        batches = []
        cur_max_len = 0
        batch = []
        for length, i in indices:
            max_len = max(length, cur_max_len)
            if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence:
                left_sample = len(batch) % self.need_be_multiple_of
                add_samples = batch.copy()
                cur_max_len = length
                if left_sample != 0:
                    add_samples = add_samples[:-left_sample]
                    batch = batch[-left_sample:]
                    cur_max_len = max(cur_max_len, max(batch))
                else:
                    batch = []
                if len(add_samples) == 0:
                    raise RuntimeError(
                        f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
                batches.append(add_samples)
            else:
                cur_max_len = max_len
            batch.append(i)
        if batch:
            left_sample = len(batch) % self.need_be_multiple_of
            add_samples = batch.copy()
            if left_sample != 0:
                add_samples = add_samples[:-left_sample].copy()
            if add_samples:
                batches.append(add_samples)
        np.random.shuffle(batches)
        self.batches = batches

    def __iter__(self) -> Iterable[int]:
        for batch in self.batches:
            yield batch
        self.get_new_order()

    def __len__(self):
        return len(self.batches)


class SortedSampler(Sampler):
    r"""
    Sort samples by length; mainly used at test time, where it can speed up evaluation (because it reduces padding).
    """

    def __init__(self, dataset, seq_len_field_name: str = 'seq_len', descending: bool = True) -> None:
        """

        :param str seq_len_field_name: which field to sort by. If the field contains numbers, sort directly by them;
            otherwise sort by the length of the field's entries.
        :param bool descending: whether to sort in descending order
        """
        self.dataset = dataset
        self.seq_len_field_name = seq_len_field_name
        self.descending = descending

    def __iter__(self) -> Iterable[int]:
        seq_lens = self.dataset.get_field(self.seq_len_field_name).content
        try:
            seq_lens = list(map(len, seq_lens))
        except:
            pass

        orders = np.argsort(seq_lens).tolist()  # ascending order
        if self.descending:
            orders = orders[::-1]
        for order in orders:
            yield order


def simple_sort_bucketing(lengths):
    r"""

    :param lengths: list of int, the lengths of all examples.
    :return data: 2-level list
        ::

            [
                [index_11, index_12, ...],  # bucket 1
                [index_21, index_22, ...],  # bucket 2
                ...
            ]

    """
    lengths_mapping = [(idx, length) for idx, length in enumerate(lengths)]
    sorted_lengths = sorted(lengths_mapping, key=lambda x: x[1])
    # TODO: need to return buckets
    return [idx for idx, _ in sorted_lengths]


def k_means_1d(x, k, max_iter=100):
    r"""Perform k-means on 1-D data.

    :param x: list of int, representing points in 1-D.
    :param k: the number of clusters required.
    :param max_iter: maximum iteration
    :return centroids: numpy array, centroids of the k clusters
        assignment: numpy array, 1-D, the bucket id assigned to each example.
    """
    sorted_x = sorted(list(set(x)))
    x = np.array(x)
    if len(sorted_x) < k:
        raise ValueError("too few buckets")
    gap = len(sorted_x) / k

    centroids = np.array([sorted_x[int(x * gap)] for x in range(k)])
    assign = None

    for i in range(max_iter):
        # Cluster Assignment step
        assign = np.array([np.argmin([np.absolute(x_i - x) for x in centroids]) for x_i in x])
        # Move centroids step
        new_centroids = np.array([x[assign == k].mean() for k in range(k)])
        if (new_centroids == centroids).all():
            centroids = new_centroids
            break
        centroids = new_centroids
    return np.array(centroids), assign


def k_means_bucketing(lengths, buckets):
    r"""Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths.

    :param lengths: list of int, the length of all samples.
    :param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length
        threshold for each bucket (This is usually None.).
    :return data: 2-level list
        ::

            [
                [index_11, index_12, ...],  # bucket 1
                [index_21, index_22, ...],  # bucket 2
                ...
            ]

    """
    bucket_data = [[] for _ in buckets]
    num_buckets = len(buckets)
    _, assignments = k_means_1d(lengths, num_buckets)

    for idx, bucket_id in enumerate(assignments):
        if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
            bucket_data[bucket_id].append(idx)
    return bucket_data
@ -203,7 +203,7 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
    :return:
    """
    if fn_name is not None:
        assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}."
        assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`."

    parameters = list(inspect.signature(fn).parameters.values())
    if inspect.ismethod(fn):

@ -606,16 +606,38 @@ def seq_len_to_mask(seq_len, max_len=None):
    return mask


def wait_to_success(fn, no=False):
def wait_filepath(path, exist=True):
    """
    Return once the existence state of `path` equals {exist}.

    :param path: the path to watch
    :param exist: when True, return as soon as the path exists; when False, return as soon as the path no longer exists.
    :return:
    """
    if isinstance(path, str):
        path = Path(path)
    assert isinstance(path, Path)
    count = 0
    while True:
        sleep(0.01)
        if (no and not fn()) or (not no and fn()):
        if path.exists() == exist:
            break
        count += 1
        if count % 1000 == 0:
            msg = 'create' if exist else 'delete'
            logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...")



# This exists because errors can occur on distributed file systems: rank0 issues the delete and moves on, but the actual
# deletion has to travel from rank0's machine to the remote file system before it runs; at that moment
# rank0 believes the deletion succeeded while the operation has not finished remotely, so rank1 still sees the file when it reads;
def synchronize_safe_rm(path: Optional[Union[str, Path]]):
    """
    This exists because errors can occur on distributed file systems: rank0 issues the delete and moves on, but the actual
    deletion has to travel from rank0's machine to the remote file system before it runs; at that moment
    rank0 believes the deletion succeeded while the operation has not finished remotely, so rank1 still sees the file when it reads;
    this function guarantees that every process has observed the deletion of path before returning. Make sure path is exactly the same on every process, otherwise a deadlock will occur.

    :param path:
    :return:
    """
    if path is None:
        return
    if isinstance(path, str):
@ -624,7 +646,7 @@ def synchronize_safe_rm(path: Optional[Union[str, Path]]):
        return
    if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
        _recursive_rm(path)
    wait_to_success(path.exists, no=True)
    wait_filepath(path, exist=False)


def _recursive_rm(path: Path):
@ -643,6 +665,8 @@ def _recursive_rm(path: Path):
def synchronize_mkdir(path: Optional[Union[str, Path]]):
    """
    Note that this function creates a directory; do not use it if you need to create a file;
    it guarantees that every process has observed the creation of path before returning. Make sure path is exactly the same on every process, otherwise a deadlock will occur.

    """
    if path is None:
        return
@ -652,7 +676,7 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):
    if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
        path.mkdir(parents=True, exist_ok=True)

    wait_to_success(path.exists)
    wait_filepath(path, exist=True)


def get_class_that_defined_method(method):

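A self-contained rendition of the polling loop `wait_filepath` replaces the old `wait_to_success` with; the helper below is a sketch that mirrors the diff (the real one uses the fastNLP logger rather than `print`):

    from pathlib import Path
    from time import sleep

    def wait_filepath(path, exist=True, poll=0.01):
        path = Path(path)
        count = 0
        while path.exists() != exist:
            sleep(poll)
            count += 1
            if count % 1000 == 0:
                msg = 'create' if exist else 'delete'
                print(f"Waiting path:{path} to {msg} for {count * poll} seconds...")

    # e.g. rank 0 creates a folder, every other rank blocks until it is visible:
    # wait_filepath('/tmp/shared_flag_dir', exist=True)
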
@ -5,7 +5,6 @@
import os
import json
import sys
import subprocess
from collections import defaultdict


@ -50,8 +50,6 @@ class ConllLoader(Loader):

    The fields of the DataSet returned by ConllLoader are determined by the headers passed in.

    Lines starting with "-DOCSTART-" in the data are ignored, because that symbol is used as the document separator in conll 2003.

    """

    def __init__(self, headers, sep=None, indexes=None, dropna=True):
@ -93,6 +91,7 @@ class ConllLoader(Loader):
class Conll2003Loader(ConllLoader):
    r"""
    Reads data for the conll2003 task. The data content should look like the following: the first column is raw_words, the second is pos, the third is chunking, and the fourth is ner.
    Lines starting with "-DOCSTART-" in the data are ignored, because that symbol is used as the document separator in conll 2003.

    Example::

@ -85,7 +85,7 @@ class MixModule:
    def test_step(self, batch):
        raise NotImplementedError

    def validate_step(self, batch):
    def evaluate_step(self, batch):
        raise NotImplementedError

    def train(self):

0 tests/core/callbacks/torch_callbacks/__init__.py (Normal file)
@ -0,0 +1,41 @@
import pytest
import numpy as np

from fastNLP.core.callbacks import TorchGradClipCallback, Callback
from fastNLP import Trainer
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    import torch

from tests.helpers.callbacks.prepare_trainer_args_for_torch_test import get_trainer_args


class CheckClipCallback(Callback):
    def __init__(self, parameters, clip_type, clip_value):
        self.parameters = parameters
        self.clip_type = clip_type
        self.clip_value = clip_value

    def on_after_optimizers_step(self, trainer, optimizers):
        for param in self.parameters:
            if self.clip_type == 'value':
                assert param.grad.max().item() <= self.clip_value
            else:
                assert np.linalg.norm(param.grad.cpu().view(-1).numpy()) <= self.clip_value


@pytest.mark.parametrize('accumulation_steps', [1, 3, 5])
@pytest.mark.parametrize('fp16', [True, False])
@pytest.mark.parametrize('clip_type', ['norm', 'value'])
@pytest.mark.parametrize('clip_value', [1, 2])
def test_torch_grad_clip_callback(accumulation_steps, fp16, clip_type, clip_value):
    if not torch.cuda.is_available() and fp16:
        pytest.skip("No cuda, cannot test fp16.")
    device = 'cuda' if fp16 else 'cpu'
    kwargs = get_trainer_args(lr=1, device=device)
    callbacks = []
    callbacks.append(TorchGradClipCallback(clip_value=clip_value, clip_type=clip_type))
    callbacks.append(CheckClipCallback(kwargs['model'].parameters(), clip_type, clip_value))
    trainer = Trainer(**kwargs, callbacks=callbacks, fp16=fp16)
    trainer.run()
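For reference, typical usage of the new callback outside the test harness might look like the following; the model/optimizer/dataloader names are placeholders, and only the parameters the test exercises (`clip_value`, `clip_type`) are shown:

    from fastNLP import Trainer
    from fastNLP.core.callbacks import TorchGradClipCallback

    callbacks = [TorchGradClipCallback(clip_value=1.0, clip_type='norm')]
    # trainer = Trainer(model=model, optimizers=optimizer,
    #                   train_dataloader=train_dl, driver='torch',
    #                   device='cpu', callbacks=callbacks)
    # trainer.run()
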
@ -0,0 +1,34 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.callbacks import TorchWarmupCallback, Callback
|
||||
from fastNLP import Trainer
|
||||
|
||||
from tests.helpers.callbacks.prepare_trainer_args_for_torch_test import get_trainer_args
|
||||
|
||||
|
||||
class RecordLrCallback(Callback):
|
||||
def __init__(self):
|
||||
self.lrs = []
|
||||
|
||||
def on_after_optimizers_step(self, trainer, optimizers):
|
||||
self.lrs.append(trainer.driver.optimizers[0].param_groups[0]['lr'])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('warmup', [5, 0.1])
|
||||
@pytest.mark.parametrize('schedule', ['constant', 'linear'])
|
||||
@pytest.mark.parametrize('accumulation_steps', [1, 3, 4])
|
||||
def test_torch_warmup_callback(warmup, schedule, accumulation_steps):
|
||||
kwargs = get_trainer_args(lr=0.1, bsz=4)
|
||||
callback = TorchWarmupCallback(warmup, schedule)
|
||||
r_callback = RecordLrCallback()
|
||||
kwargs['callbacks'] = [callback, r_callback]
|
||||
trainer = Trainer(**kwargs, accumulation_steps=accumulation_steps)
|
||||
trainer.run()
|
||||
|
||||
if schedule == 'linear':
|
||||
assert kwargs['optimizers'].param_groups[0]['lr'] <= 0.01
|
||||
elif schedule == 'constant':
|
||||
assert np.allclose(0.1, kwargs['optimizers'].param_groups[0]['lr'])
|
||||
|
||||
assert len(r_callback.lrs)<=trainer.total_batches//accumulation_steps+1
|
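The 'linear' assertion above is consistent with a warmup factor that rises to 1 over the warmup steps and then decays linearly to 0; whether TorchWarmupCallback computes exactly this is an assumption here, so the sketch below only illustrates the shape the test checks:

    # Hypothetical schedule shape; not the callback's actual implementation.
    def linear_warmup_factor(step, warmup_steps, total_steps):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (total_steps - step) / max(1.0, total_steps - warmup_steps))

    lr = 0.1
    print(lr * linear_warmup_factor(step=95, warmup_steps=10, total_steps=100))  # well below 0.01
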
@ -1,13 +1,11 @@
import pytest
import os
os.environ["FASTNLP_BACKEND"] = "paddle"
from typing import Any
from dataclasses import dataclass

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK

from paddle.optimizer import Adam
from paddle.io import DataLoader
@ -19,40 +17,18 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordM
from tests.helpers.utils import magic_argv_env_context

@dataclass
class MNISTTrainPaddleConfig:
class TrainPaddleConfig:
    num_labels: int = 10
    feature_dimension: int = 784
    feature_dimension: int = 10

    batch_size: int = 32
    batch_size: int = 2
    shuffle: bool = True
    validate_every = -5
    evaluate_every = 2

    driver: str = "paddle"
    device = "gpu"

@dataclass
class MNISTTrainFleetConfig:
    num_labels: int = 10
    feature_dimension: int = 784

    batch_size: int = 32
    shuffle: bool = True
    validate_every = -5

@dataclass
class TrainerParameters:
    model: Any = None
    optimizers: Any = None
    train_dataloader: Any = None
    validate_dataloaders: Any = None
    input_mapping: Any = None
    output_mapping: Any = None
    metrics: Any = None

@pytest.mark.parametrize("driver,device", [("paddle", "cpu")("paddle", 1)])
@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
                                        RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
                                        RichCallback(5)]])
@magic_argv_env_context
def test_trainer_paddle(
    driver,
@ -60,38 +36,36 @@ def test_trainer_paddle(
    callbacks,
    n_epochs=2,
):
    trainer_params = TrainerParameters()

    trainer_params.model = PaddleNormalModel_Classification_1(
        num_labels=MNISTTrainPaddleConfig.num_labels,
        feature_dimension=MNISTTrainPaddleConfig.feature_dimension
    model = PaddleNormalModel_Classification_1(
        num_labels=TrainPaddleConfig.num_labels,
        feature_dimension=TrainPaddleConfig.feature_dimension
    )
    trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
    optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
    train_dataloader = DataLoader(
        dataset=PaddleRandomMaxDataset(6400, 10),
        batch_size=MNISTTrainPaddleConfig.batch_size,
        dataset=PaddleRandomMaxDataset(20, 10),
        batch_size=TrainPaddleConfig.batch_size,
        shuffle=True
    )
    val_dataloader = DataLoader(
        dataset=PaddleRandomMaxDataset(1000, 10),
        batch_size=MNISTTrainPaddleConfig.batch_size,
        dataset=PaddleRandomMaxDataset(20, 10),
        batch_size=TrainPaddleConfig.batch_size,
        shuffle=True
    )
    trainer_params.train_dataloader = train_dataloader
    trainer_params.validate_dataloaders = val_dataloader
    trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
    trainer_params.metrics = {"acc": Accuracy(backend="paddle")}
    train_dataloader = train_dataloader
    evaluate_dataloaders = val_dataloader
    evaluate_every = TrainPaddleConfig.evaluate_every
    metrics = {"acc": Accuracy(backend="paddle")}
    trainer = Trainer(
        model=trainer_params.model,
        model=model,
        driver=driver,
        device=device,
        optimizers=trainer_params.optimizers,
        train_dataloader=trainer_params.train_dataloader,
        validate_dataloaders=trainer_params.validate_dataloaders,
        validate_every=trainer_params.validate_every,
        input_mapping=trainer_params.input_mapping,
        output_mapping=trainer_params.output_mapping,
        metrics=trainer_params.metrics,
        optimizers=optimizers,
        train_dataloader=train_dataloader,
        evaluate_dataloaders=evaluate_dataloaders,
        evaluate_every=evaluate_every,
        input_mapping=None,
        output_mapping=None,
        metrics=metrics,

        n_epochs=n_epochs,
        callbacks=callbacks,
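
Distilled from the hunks above, this is the new-style wiring the test now exercises; the values mirror TrainPaddleConfig, and nothing here goes beyond what the diff itself shows. The visible API change is that validate_dataloaders/validate_every became evaluate_dataloaders/evaluate_every:

model = PaddleNormalModel_Classification_1(num_labels=10, feature_dimension=10)
trainer = Trainer(
    model=model,
    driver="paddle",
    device="cpu",
    optimizers=Adam(parameters=model.parameters(), learning_rate=0.0001),
    train_dataloader=DataLoader(PaddleRandomMaxDataset(20, 10), batch_size=2, shuffle=True),
    evaluate_dataloaders=DataLoader(PaddleRandomMaxDataset(20, 10), batch_size=2, shuffle=True),
    evaluate_every=2,  # replaces the old validate_every
    metrics={"acc": Accuracy(backend="paddle")},
    n_epochs=2,
    callbacks=[RichCallback(5)],
)
trainer.run()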
@ -117,12 +117,13 @@ class TestSetDistReproDataloader:
    """

    @magic_argv_env_context
    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is a BucketedBatchSampler
        """
        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
        batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
        batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)

        assert not (replaced_loader is dataloader)
@ -133,12 +134,13 @@ class TestSetDistReproDataloader:
        dist.barrier()

    @magic_argv_env_context
    def test_set_dist_repro_dataloader_with_dist_sampler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is a RandomSampler
        """
        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
        sampler = RandomSampler(self.dataset, shuffle=True)
        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
        sampler = RandomSampler(self.dataset, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)

        assert not (replaced_loader is dataloader)
@ -171,14 +173,15 @@ class TestSetDistReproDataloader:
        dist.barrier()

    @magic_argv_env_context
    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False,
        and the dataloader has a BucketedBatchSampler
        """
        dataloader = DataLoader(
            self.dataset,
            batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4),
            batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle),
        )
        dataloader.batch_sampler.set_distributed(
            num_replicas=self.driver.world_size,
@ -195,12 +198,13 @@ class TestSetDistReproDataloader:
        dist.barrier()

    @magic_argv_env_context
    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False,
        and the dataloader has a RandomSampler
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
        batch_sampler.sampler = RandomSampler(self.dataset, True)
        batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
        batch_sampler.sampler.set_distributed(
            num_replicas=self.driver.world_size,
            rank=self.driver.global_rank
@ -222,11 +226,12 @@ class TestSetDistReproDataloader:
        dist.barrier()

    @magic_argv_env_context
    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False,
        and the dataloader is an ordinary one
        """
        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

        assert replaced_loader is dataloader
@ -238,14 +243,15 @@ class TestSetDistReproDataloader:
    """

    @magic_argv_env_context
    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'dist'
        and dataloader.batch_sampler is a ReproducibleBatchSampler
        """
        dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
            batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

@ -258,13 +264,14 @@ class TestSetDistReproDataloader:
        dist.barrier()

    @magic_argv_env_context
    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'dist'
        and dataloader.batch_sampler.sampler is a ReproducibleSampler
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
        batch_sampler.sampler = RandomSampler(self.dataset, True)
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
        batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
        dataloader = DataLoader(
            self.dataset,
            batch_sampler=batch_sampler
@ -276,16 +283,17 @@ class TestSetDistReproDataloader:
        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
        assert replaced_loader.batch_sampler.batch_size == 2
        assert replaced_loader.batch_sampler.sampler.shuffle == True
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
        dist.barrier()

    @magic_argv_env_context
    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'dist' and the dataloader is an ordinary one
        """
        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

        assert not (replaced_loader is dataloader)
@ -293,7 +301,7 @@ class TestSetDistReproDataloader:
        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
        assert replaced_loader.batch_sampler.sampler.shuffle == True
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
        dist.barrier()

    """
@ -302,13 +310,14 @@ class TestSetDistReproDataloader:
    """

    @magic_argv_env_context
    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist'
        and dataloader.batch_sampler.sampler is a ReproducibleSampler
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
        batch_sampler.sampler = RandomSampler(self.dataset, True)
        batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
        dataloader = DataLoader(
            self.dataset,
            batch_sampler=batch_sampler
@ -320,18 +329,19 @@ class TestSetDistReproDataloader:
        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
        assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
        assert replaced_loader.batch_sampler.batch_size == 2
        assert replaced_loader.batch_sampler.sampler.shuffle == True
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
        dist.barrier()

    @magic_argv_env_context
    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist'
        and dataloader.batch_sampler.sampler is an UnrepeatedSampler
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
        batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True)
        batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle)
        dataloader = DataLoader(
            self.dataset,
            batch_sampler=batch_sampler
@ -349,11 +359,12 @@ class TestSetDistReproDataloader:
        dist.barrier()

    @magic_argv_env_context
    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and the dataloader is an ordinary one
        """
        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

        assert not (replaced_loader is dataloader)
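
Taken together, these distributed cases pin down a small dispatch table for set_dist_repro_dataloader. A plain-data summary, hedged: it restates only what the assertions above establish, and the real driver implementation handles more cases than these:

# keys describe the `dist` argument, values the outcome the tests assert
SET_DIST_REPRO_DATALOADER_CASES = {
    "ReproducibleBatchSampler instance (e.g. BucketedBatchSampler)":
        "new dataloader whose batch_sampler is the given sampler",
    "ReproducibleSampler instance (e.g. RandomSampler)":
        "new dataloader whose batch_sampler.sampler is the given sampler",
    "'dist'":
        "new dataloader with a RandomSampler; batch_size and shuffle are preserved",
    "'unrepeatdist'":
        "new dataloader whose sampler is an UnrepeatedRandomSampler",
    "None with reproducible=False":
        "samplers already attached are kept; an ordinary dataloader is returned unchanged",
}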
@ -1,4 +1,5 @@
import os
from re import S
os.environ["FASTNLP_BACKEND"] = "paddle"
import pytest
from pathlib import Path
@ -56,34 +57,57 @@ def test_save_and_load_with_randombatchsampler(only_state_dict):
            dataset=dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
        )
        num_consumed_batches = 2

        # TODO iterate a few more times here once checkpoint-resume is fully implemented
        already_seen_set = set()
        for idx, batch in enumerate(dataloader):
            if idx >= num_consumed_batches:
                break
            already_seen_set.update(batch)

        sampler_states = dataloader.batch_sampler.state_dict()
        save_states = {"num_consumed_batches": num_consumed_batches}
        if only_state_dict:
            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
        else:
            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
        states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])

        # load
        # change the batch_size
        dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
        )
        load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
        replaced_loader = load_states.pop("dataloader")

        # 1. check the optimizer state
        # TODO the optimizer's state_dict is always empty

        # 2. check that the batch_sampler is correctly loaded and replaced
        replaced_loader = states["dataloader"]
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
        assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"]

        # 3. check that the model parameters are correctly loaded
        for batch in dataloader:
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)
            res1 = driver1.model.evaluate_step(**batch)
            res2 = driver2.model.evaluate_step(**batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])

        # 4. check batch_idx
        # TODO
        start_batch = load_states.pop('batch_idx_in_epoch')
        assert start_batch == 2 * num_consumed_batches
        left_batches = set()
        for idx, batch in enumerate(replaced_loader):
            left_batches.update(batch)

        assert len(left_batches) + len(already_seen_set) == len(dataset)
        assert len(left_batches | already_seen_set) == len(dataset)


    finally:
        synchronize_safe_rm(path)

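
The resume invariant these checks build up to, as a standalone sketch: after restoring the sampler state, the batches still to come plus the batches already seen must cover the whole dataset. Both assertion forms appear in the diff (the sum-of-sizes form and the union form):

def check_resume_invariant(already_seen_set, replaced_loader, dataset):
    left_batches = set()
    for batch in replaced_loader:
        left_batches.update(batch)
    # nothing lost and nothing repeated across the save/load boundary
    assert len(left_batches) + len(already_seen_set) == len(dataset)
    assert len(left_batches | already_seen_set) == len(dataset)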
@ -104,21 +128,36 @@ def test_save_and_load_with_randomsampler(only_state_dict):
            dataset,
            batch_sampler=batch_sampler
        )
        num_consumed_batches = 2

        # TODO iterate a few more times here once checkpoint-resume is fully implemented
        already_seen_set = set()
        for idx, batch in enumerate(dataloader):
            if idx >= num_consumed_batches:
                break
            already_seen_set.update(batch)

        sampler_states = dataloader.batch_sampler.sampler.state_dict()
        save_states = {"num_consumed_batches": num_consumed_batches}
        if only_state_dict:
            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
        else:
            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
        states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])

        # load
        # change the batch_size
        dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
        )
        load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
        replaced_loader = load_states.pop("dataloader")

        # 1. check the optimizer state
        # TODO the optimizer's state_dict is always empty

        # 2. check that the sampler is correctly loaded and replaced
        replaced_loader = states["dataloader"]
        replaced_loader = load_states["dataloader"]

        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
@ -129,60 +168,51 @@ def test_save_and_load_with_randomsampler(only_state_dict):

        # 3. check that the model parameters are correctly loaded
        for batch in dataloader:
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)
            res1 = driver1.model.evaluate_step(**batch)
            res2 = driver2.model.evaluate_step(**batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])

        # 4. check batch_idx
        # TODO
        start_batch = load_states.pop('batch_idx_in_epoch')
        assert start_batch == 2 * num_consumed_batches
        left_batches = set()
        for idx, batch in enumerate(replaced_loader):
            left_batches.update(batch)

        assert len(left_batches) + len(already_seen_set) == len(dataset)
        assert len(left_batches | already_seen_set) == len(dataset)
    finally:
        synchronize_safe_rm(path)

def test_save_and_load_state_dict(prepare_test_save_load):
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_model(prepare_test_save_load, only_state_dict):
    """
    Test the save and load functions
    TODO the optimizer state_dict is empty; not tested for now
    """
    try:
        path = "dict"
        driver1, driver2, dataloader = prepare_test_save_load

        driver1.save_model(path)
        driver2.load_model(path)

        for batch in dataloader:
            batch = driver1.move_data_to_device(batch)
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])
    finally:
        synchronize_safe_rm(path)

def test_save_and_load_whole_model(prepare_test_save_load):
    """
    Test the save and load functions
    TODO the optimizer state_dict is empty; not tested for now
    Test the save_model and load_model functions
    """
    try:
        path = "model"
        driver1, driver2, dataloader = prepare_test_save_load

        driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))])
        driver2.load_model(path, only_state_dict=False)
        if only_state_dict:
            driver1.save_model(path, only_state_dict)
        else:
            driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((32, 10))])
        driver2.load_model(path, only_state_dict)

        for batch in dataloader:
            batch = driver1.move_data_to_device(batch)
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)
            res1 = driver1.model.evaluate_step(**batch)
            res2 = driver2.model.evaluate_step(**batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])
    finally:
        synchronize_safe_rm(path + ".pdiparams")
        synchronize_safe_rm(path + ".pdiparams.info")
        synchronize_safe_rm(path + ".pdmodel")

        if only_state_dict:
            synchronize_safe_rm(path)
        else:
            synchronize_safe_rm(path + ".pdiparams")
            synchronize_safe_rm(path + ".pdiparams.info")
            synchronize_safe_rm(path + ".pdmodel")

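
A hedged sketch of the two save formats the parametrized test covers. paddle.save and paddle.jit.save are real Paddle APIs; the exact fastNLP driver internals may differ:

import paddle

def save_model_sketch(model, path, only_state_dict, input_spec=None):
    if only_state_dict:
        # one file holding only the parameters
        paddle.save(model.state_dict(), path)
    else:
        # exports the whole program: produces path.pdmodel, path.pdiparams and
        # path.pdiparams.info, and needs input_spec to trace the graph
        paddle.jit.save(model, path, input_spec=input_spec)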
class TestSingleDeviceFunction:
    """
@ -199,13 +229,7 @@ class TestSingleDeviceFunction:
        Test that it runs
        """
        res = self.driver.unwrap_model()

    def test_check_evaluator_mode(self):
        """
        These two calls neither return a value nor raise; this only checks for import errors
        and similar issues that would prevent running
        """
        self.driver.check_evaluator_mode("validate")
        self.driver.check_evaluator_mode("test")
        assert res is self.driver.model

    def test_is_distributed(self):
        assert self.driver.is_distributed() == False
@ -237,44 +261,55 @@ class TestSetDistReproDataloder:

        assert replaced_loader is dataloader

    def test_set_dist_repro_dataloader_with_reproducible_true(self):
    @pytest.mark.parametrize("shuffle", [True, False])
    def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when the `reproducible` argument is True
        When dist is a string, a new dataloader should be returned whose batch_sampler is a RandomBatchSampler
        When dist is a string, a new dataloader should be returned; if the original sampler is paddle.io.RandomSampler(shuffle=True),
        only the sampler is replaced with a RandomSampler; otherwise the batch_sampler is replaced with a RandomBatchSampler
        """
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)

        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
        if shuffle:
            # in this case the sampler is replaced
            assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler)
            assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
            assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        else:
            # in this case the batch_sampler is replaced
            assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
            assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
        assert replaced_loader.drop_last == dataloader.drop_last

        # self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when the dist argument is not a string but a ReproducibleBatchSampler
        A new dataloader should be returned, with batch_sampler replaced by the sampler passed as dist
        """
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
        dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
        dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert replaced_loader.batch_sampler is dist

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    def test_set_dist_repro_dataloader_with_dist_sampler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when the dist argument is not a string
        A new dataloader should be returned, with batch_sampler.sampler replaced by the sampler passed as dist
        """
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
        dist = RandomSampler(self.dataset, shuffle=True)
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
        dist = RandomSampler(self.dataset, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

        assert not (replaced_loader is dataloader)
@ -284,16 +319,21 @@ class TestSetDistReproDataloder:
        assert replaced_loader.batch_sampler.sampler is dist
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when the dataloader already supports resumable training
        A new dataloader should be returned, with all other settings identical to the original
        """
        dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
            batch_sampler=RandomBatchSampler(
                BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
                batch_size=4,
                drop_last=False,
            )
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

@ -303,15 +343,16 @@ class TestSetDistReproDataloder:
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
        assert replaced_loader.drop_last == dataloader.drop_last

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle):
        """
        Test the behavior of set_dist_repro_dataloader when the dataloader already supports resumable training
        A new dataloader should be returned, with all other settings identical to the original
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
        batch_sampler.sampler = RandomSampler(self.dataset, True)
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
        batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
        dataloader = DataLoader(
            self.dataset,
            batch_sampler=batch_sampler
@ -323,11 +364,11 @@ class TestSetDistReproDataloder:
        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
        assert replaced_loader.batch_sampler.batch_size == 2
        assert replaced_loader.batch_sampler.sampler.shuffle == True
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
        """
        Check that set_dist_repro_dataloader produces the correct result on a single device
        """
@ -346,9 +387,6 @@ class TestSetDistReproDataloder:
        # load num_consumed_samples_array and set the correct number of batches already taken
        num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)

        import time
        time.sleep(5)

        # reload; the remaining content should come out, and for PaddleNormalDataset the sorted result should form a range
        left_idxes = set()
        if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
@ -357,16 +395,29 @@ class TestSetDistReproDataloder:
                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
            else:
                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
            replaced_loader.batch_sampler.load_state_dict(sampler_states)
            # rebuild the dataloader
            new_loader = DataLoader(
                dataset=replaced_loader.dataset,
                batch_sampler=RandomBatchSampler(
                    BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
                    batch_size=batch_size,
                    drop_last=False,
                )
            )
            new_loader.batch_sampler.load_state_dict(sampler_states)
        else:
            batch_size = replaced_loader.batch_sampler.batch_size
            num_consumed_batches = num_consumed_batches * batch_size
            if num_consumed_samples_array is not None:
                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
            else:
                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
            replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
            replaced_loader.batch_sampler.sampler.set_epoch(0)
        for idx, batch in enumerate(replaced_loader):
            # rebuild the dataloader
            batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
            batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
            new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
            new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
        for idx, batch in enumerate(new_loader):
            left_idxes.update(batch)

        assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)

@ -1,31 +0,0 @@
import unittest
import random
from fastNLP.core.samplers import SequentialSampler, RandomSampler, BucketSampler
from fastNLP.core.dataset import DataSet
from array import array
import torch

from fastNLP.core.samplers.sampler import ReproduceBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset


class SamplerTest(unittest.TestCase):

    def test_sequentialsampler(self):
        ds = DataSet({'x': [1, 2, 3, 4] * 10})
        sqspl = SequentialSampler(ds)
        for idx, inst in enumerate(sqspl):
            self.assertEqual(idx, inst)

    def test_randomsampler(self):
        ds = DataSet({'x': [1, 2, 3, 4] * 10})
        rdspl = RandomSampler(ds)
        ans = [ds[i] for i in rdspl]
        self.assertEqual(len(ans), len(ds))

    def test_bucketsampler(self):
        data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10})
        sampler = BucketSampler(data_set, num_buckets=3, batch_size=16, seq_len_field_name="seq_len")

@ -1,6 +1,6 @@
import os

from fastNLP.envs.set_env import dump_fastnlp_backend
from fastNLP.envs.set_backend import dump_fastnlp_backend
from tests.helpers.utils import Capturing
from fastNLP.core import synchronize_safe_rm

@ -72,7 +72,7 @@ class RecordTrainerEventTriggerCallback(Callback):
        print("on_train_end")

    def on_train_epoch_begin(self, trainer):
        if trainer.current_epoch_idx >= 1:
        if trainer.cur_epoch_idx >= 1:
            # trigger on_exception
            raise Exception
        print("on_train_epoch_begin")
@ -0,0 +1,68 @@

"""
This file provides Trainer arguments for the callback tests; the returned dict can be used
directly to initialize a Trainer. Passing in the relevant callbacks on top is all that is
needed to run.
"""

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.metrics import Accuracy


if _NEED_IMPORT_TORCH:
    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    import torch.nn.functional as F

class DataSet:
    def __init__(self, num_samples=1000, num_features=10):
        g = torch.Generator()
        g.manual_seed(1000)
        self.data = torch.randn(num_samples, num_features, generator=g)
        self.y = self.data.argmax(dim=-1)

    def __getitem__(self, item):
        return {'x': self.data[item], 'target': self.y[item]}

    def __len__(self):
        return len(self.data)


class Model(nn.Module):
    def __init__(self, num_features=5):
        super().__init__()
        self.mlps = nn.Sequential(
            nn.Linear(num_features, 20),
            nn.ReLU(),
            nn.Linear(20, 20),
            nn.Dropout(p=0.3),
            nn.ReLU(),
            nn.Linear(20, num_features)
        )

    def forward(self, x, target):
        y = self.mlps(x)
        if self.training:
            return {'loss': F.cross_entropy(y, target)}
        return {'pred': y}


def get_trainer_args(num_features=5, num_samples=20, bsz=4, lr=0.1, n_epochs=5, device=None):
    ds = DataSet(num_samples=num_samples, num_features=num_features)
    dl = DataLoader(ds, batch_size=bsz)
    model = Model(num_features=num_features)

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    kwargs = {
        'model': model,
        'driver': 'torch',
        'device': device,
        'optimizers': optimizer,
        'train_dataloader': dl,
        'evaluate_dataloaders': dl,
        'metrics': {'acc': Accuracy()},
        'n_epochs': n_epochs
    }

    return kwargs
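
Typical usage, as in the callback tests earlier in this commit: take the returned dict, append the callbacks under test, and pass any extra Trainer keyword arguments on top:

kwargs = get_trainer_args(lr=0.1, bsz=4)
kwargs['callbacks'] = [TorchWarmupCallback(5, 'linear')]
trainer = Trainer(**kwargs, accumulation_steps=3)
trainer.run()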
@ -26,7 +26,7 @@ class PaddleNormalModel_Classification_1(paddle.nn.Layer):
        x = self(x)
        return {"loss": self.loss_fn(x, y)}

    def validate_step(self, x, y):
    def evaluate_step(self, x, y):

        x = self(x)
        return {"pred": x, "target": y.reshape((-1,))}

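A hedged sketch of the naming convention this rename settles on: during evaluation the driver now calls the model's evaluate_step (formerly validate_step), while training goes through train_step. The layer below is illustrative only, modeled on PaddleNormalModel_Classification_1:

import paddle
import paddle.nn as nn

class TinyClassifier(nn.Layer):
    def __init__(self, num_labels=10, feature_dimension=10):
        super().__init__()
        self.linear = nn.Linear(feature_dimension, num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.linear(x)

    def train_step(self, x, y):
        # invoked by the driver during training; must return a dict with 'loss'
        return {"loss": self.loss_fn(self(x), y)}

    def evaluate_step(self, x, y):
        # invoked by the driver during evaluation (this is the renamed hook)
        return {"pred": self(x), "target": y.reshape((-1,))}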