Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

MorningForest 2022-04-15 20:04:53 +08:00
commit a2956b697e
40 changed files with 1307 additions and 1452 deletions

View File

@@ -11,7 +11,10 @@ __all__ = [
'RichCallback',
"LRSchedCallback",
'LoadBestModelCallback',
"EarlyStopCallback"
"EarlyStopCallback",
"TorchWarmupCallback",
"TorchGradClipCallback"
]
@@ -23,4 +26,5 @@ from .progress_callback import choose_progress_callback, ProgressCallback, RichC
from .lr_scheduler_callback import LRSchedCallback
from .load_best_model_callback import LoadBestModelCallback
from .early_stop_callback import EarlyStopCallback
from .torch_callbacks import *

View File

@@ -1,16 +1,12 @@
from typing import Union, Callable, Dict, Optional, Any
from abc import ABC
__all__ = [
'Callback',
]
from typing import Union, Callable, Dict, Optional, Any
from .callback_events import Events, EventsList, Filter
from .utils import _get_monitor_value
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
from fastNLP.core.utils.utils import _check_valid_parameters_number
class Callback:
@@ -278,135 +274,3 @@ class _CallbackWrapper(Callback):
@property
def callback_name(self):
return self.fn.__name__
class CanItemDataType(ABC):
"""
Detect objects that can be transmitted (i.e., objects with a callable ``item`` method)
"""
@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanItemDataType:
item = getattr(subclass, 'item', None)
return callable(item)
return NotImplemented
class HasMonitorCallback(Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
self.set_monitor(monitor, larger_better)
self.must_have_monitor = must_have_monitor
def set_monitor(self, monitor, larger_better):
if callable(monitor): # check that it can accept a single argument
_check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
self.monitor = monitor
else:
self.monitor = str(monitor) if monitor is not None else None
self.larger_better = bool(larger_better)
if larger_better:
self.monitor_value = float('-inf')
else:
self.monitor_value = float('inf')
self._real_monitor = self.monitor
def on_after_trainer_initialized(self, trainer, driver):
"""
If this callback's own monitor is not set, set it from the monitor configured on the Trainer.
For callbacks that must have a monitor, this method also performs that check.
:param trainer:
:param driver:
:return:
"""
if self.monitor is None and trainer.monitor is not None:
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
if self.must_have_monitor and self.monitor is None:
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
f"You can set it in the initialization or through Trainer.")
def get_monitor_value(self, results: Dict) -> Union[float, None]:
"""
Get the monitor value. If the monitor key is not found directly, try fuzzy matching and store the matched key in the ``self._real_monitor`` attribute.
:param results:
:return: ``None`` means no suitable monitor was found this time.
"""
if len(results) == 0:
return None
# make sure every tensor has been converted to a plain Python type
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
if monitor_value is None:
return monitor_value
# first run
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
# the monitor differs from the one used last time
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
f"The expected monitor is:`{self.monitor}`, last used monitor is:"
f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
f"customized monitor function when the evaluation results are varying between validation.")
self._real_monitor = use_monitor
return monitor_value
def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
"""
Check whether monitor_value is better than the best value seen so far.
:param monitor_value: the monitor value to check; if it is None, False is returned.
:param keep_if_better: if the given monitor_value is better, keep it as the new best value.
:return:
"""
if monitor_value is None:
return False
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
if keep_if_better and better:
self.monitor_value = monitor_value
return better
def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
"""
Of the two given values, whether monitor_value1 is the better one.
:param monitor_value1:
:param monitor_value2:
:return:
"""
if monitor_value1 is None and monitor_value2 is None:
return True
if monitor_value1 is None:
return False
if monitor_value2 is None:
return True
better = False
if (self.larger_better and monitor_value1 > monitor_value2) or \
(not self.larger_better and monitor_value1 < monitor_value2):
better = True
return better
@property
def monitor_name(self):
"""
Return the monitor's name; if the monitor is a callable, return that function's name.
:return:
"""
if callable(self.monitor):
try:
monitor_name = self.monitor.__qualname__
except AttributeError:
monitor_name = self.monitor.__name__
elif self.monitor is None:
return None
else:
# this must be monitor rather than real_monitor, because real_monitor is re-initialized to monitor when the user runs again
monitor_name = str(self.monitor)
return monitor_name
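
For illustration (not part of this commit): `CanItemDataType` above does structural typing through `__subclasshook__`, so `isinstance(x, CanItemDataType)` holds for any object with a callable `item` method (zero-dimensional tensors, numpy scalars, and so on), with no registration or inheritance needed. A minimal stand-alone sketch of the same pattern, independent of fastNLP:

from abc import ABC
from typing import Any, Union

class SupportsItem(ABC):
    """Structural check: anything with a callable ``item`` method qualifies."""
    @classmethod
    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
        if cls is SupportsItem:
            return callable(getattr(subclass, 'item', None))
        return NotImplemented

class FakeScalar:
    def item(self) -> float:  # mimics tensor.item()
        return 0.5

assert isinstance(FakeScalar(), SupportsItem)  # has .item(), so it matches
assert not isinstance(object(), SupportsItem)  # no .item(), so it does not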

View File

@@ -10,9 +10,9 @@ from copy import deepcopy
import fastNLP
from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
@@ -217,7 +217,8 @@ class CheckpointCallback(HasMonitorCallback):
:return:
"""
folder = self.timestamp_path.joinpath(folder_name)
synchronize_mkdir(folder)
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: # only create it on process 0
synchronize_mkdir(folder)
_fn = getattr(trainer, self.save_fn_name)
_fn(
folder=folder,

View File

@@ -4,7 +4,7 @@ __all__ = [
from typing import Dict, Union, Callable
from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.utils.exceptions import EarlyStopException

View File

@@ -0,0 +1,189 @@
__all__ = [
'HasMonitorCallback',
'ExecuteOnceBetterMonitor'
]
from typing import Dict, Union, Any
from abc import ABC
from fastNLP.core.utils import apply_to_collection
from fastNLP.core.callbacks import Callback
from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.log import logger
from fastNLP.core.utils.utils import _check_valid_parameters_number
class CanItemDataType(ABC):
"""
Detect objects that can be transmitted (i.e., objects with a callable ``item`` method)
"""
@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanItemDataType:
item = getattr(subclass, 'item', None)
return callable(item)
return NotImplemented
class HasMonitorCallback(Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
"""
callback 不直接进行使用作为其它相关 callback 的父类使用如果 callback 有使用 monitor 可以继承该函数里面实现了
1判断monitor合法性2在需要时 根据trainer的monitor设置自己的monitor名称
:param monitor: 监控的 metric 如果在 evaluation 结果中没有找到完全一致的名称将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 如果为 None将尝试使用 Trainer 设置的 monitor 也可以传入一个函数接受参数为 evaluation 的结
(字典类型)返回一个 float 值作为 monitor 的结果
:param larger_better: monitor 是否时越大越好
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置如果设置为 True 且没检测到设置 monitor 会报错
"""
self.set_monitor(monitor, larger_better)
self.must_have_monitor = must_have_monitor
def set_monitor(self, monitor, larger_better):
if callable(monitor): # check that it can accept a single argument
_check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
self.monitor = monitor
else:
self.monitor = str(monitor) if monitor is not None else None
self.larger_better = bool(larger_better)
if larger_better:
self.monitor_value = float('-inf')
else:
self.monitor_value = float('inf')
self._real_monitor = self.monitor
def on_after_trainer_initialized(self, trainer, driver):
"""
If this callback's own monitor is not set, set it from the monitor configured on the Trainer.
For callbacks that must have a monitor, this method also performs that check.
:param trainer:
:param driver:
:return:
"""
if self.monitor is None and trainer.monitor is not None:
self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
if self.must_have_monitor and self.monitor is None:
raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
f"You can set it in the initialization or through Trainer.")
def get_monitor_value(self, results: Dict) -> Union[float, None]:
"""
Get the monitor value. If the monitor key is not found directly, try fuzzy matching and store the matched key in the ``self._real_monitor`` attribute.
:param results:
:return: ``None`` means no suitable monitor was found this time.
"""
if len(results) == 0:
return None
# make sure every tensor has been converted to a plain Python type
results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
real_monitor=self._real_monitor,
res=results)
if monitor_value is None:
return monitor_value
# first run
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
# the monitor differs from the one used last time
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
f"The expected monitor is:`{self.monitor}`, last used monitor is:"
f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
f"customized monitor function when the evaluation results are varying between validation.")
self._real_monitor = use_monitor
return monitor_value
def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
"""
Check whether monitor_value is better than the best value seen so far.
:param monitor_value: the monitor value to check; if it is None, False is returned.
:param keep_if_better: if the given monitor_value is better, keep it as the new best value.
:return:
"""
if monitor_value is None:
return False
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
if keep_if_better and better:
self.monitor_value = monitor_value
return better
def is_better_results(self, results, keep_if_better=True):
"""
Check whether the given results are better than last time. If no matching monitor is found in these results, return False.
:param results: the evaluation results passed to the on_validate_end() hook.
:param keep_if_better: when the new value is better, whether to store it in self.monitor_value.
:return:
"""
monitor_value = self.get_monitor_value(results)
if monitor_value is None:
return False
return self.is_better_monitor_value(monitor_value, keep_if_better=keep_if_better)
def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
"""
Of the two given values, whether monitor_value1 is the better one.
:param monitor_value1:
:param monitor_value2:
:return:
"""
if monitor_value1 is None and monitor_value2 is None:
return True
if monitor_value1 is None:
return False
if monitor_value2 is None:
return True
better = False
if (self.larger_better and monitor_value1 > monitor_value2) or \
(not self.larger_better and monitor_value1 < monitor_value2):
better = True
return better
@property
def monitor_name(self):
"""
Return the monitor's name; if the monitor is a callable, return that function's name.
:return:
"""
if callable(self.monitor):
try:
monitor_name = self.monitor.__qualname__
except AttributeError:
monitor_name = self.monitor.__name__
elif self.monitor is None:
return None
else:
# this must be monitor rather than real_monitor, because real_monitor is re-initialized to monitor when the user runs again
monitor_name = str(self.monitor)
return monitor_name
class ExecuteOnceBetterMonitor(HasMonitorCallback):
def __init__(self, monitor, larger_better, execute_fn):
"""
Call execute_fn whenever the monitored metric improves.
:param monitor: the metric to monitor. If no exactly matching name is found in the evaluation results, the shortest-common-string algorithm picks the best-matching
key as the monitor. If None, the monitor set on the Trainer is used, if any. A function may also be passed in; it receives the evaluation results
(a dict) and returns a float to use as the monitor value.
:param larger_better: whether larger monitor values are better.
:param execute_fn: a callable that takes no arguments and returns nothing; it is called whenever the monitor reaches a better result.
"""
super().__init__(monitor, larger_better, must_have_monitor=True)
_check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn')
self.execute_fn = execute_fn
def on_validate_end(self, trainer, results):
if self.is_better_results(results):
self.execute_fn()
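
For illustration (not part of this commit), a hedged usage sketch of the monitor contract above; the import path and constructor arguments follow this diff, while the metric key and the two small functions are invented:

from fastNLP.core.callbacks.has_monitor_callback import ExecuteOnceBetterMonitor

# A callable monitor receives the evaluation-results dict and returns a float,
# which bypasses the shortest-common-string matching entirely.
def f1_monitor(results: dict) -> float:
    return results['f1#f1_metric#dl1']  # hypothetical key in the results dict

callback = ExecuteOnceBetterMonitor(
    monitor=f1_monitor,      # a plain string such as 'f1' would also work
    larger_better=True,      # higher F1 is better
    execute_fn=lambda: print('new best F1 reached'),  # no args, no return value
)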

View File

@@ -4,7 +4,7 @@ __all__ = [
import os
from typing import Optional, Callable, Union
from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from io import BytesIO
import shutil
@@ -80,10 +80,7 @@ class LoadBestModelCallback(HasMonitorCallback):
self.get_monitor_value(sanity_check_res)
def on_validate_end(self, trainer, results):
monitor_value = self.get_monitor_value(results)
if monitor_value is None:
return
if self.is_better_monitor_value(monitor_value, keep_if_better=True):
if self.is_better_results(results, keep_if_better=True):
if self.real_save_folder:
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_save_fn=self.model_save_fn)

View File

@@ -8,7 +8,7 @@ __all__ = [
'RichCallback'
]
from .callback import HasMonitorCallback
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.utils import f_rich_progress
from fastNLP.core.log import logger

View File

@@ -0,0 +1,8 @@
__all__ = [
'TorchWarmupCallback',
'TorchGradClipCallback'
]
from .torch_lr_sched_callback import TorchWarmupCallback
from .torch_grad_clip_callback import TorchGradClipCallback

View File

@@ -0,0 +1,52 @@
__all__ = [
'TorchGradClipCallback'
]
from ..callback import Callback
class TorchGradClipCallback(Callback):
def __init__(self, clip_value=1, clip_type='norm', parameters=None):
r"""
Clip the parameters' gradients before each optimizer update.
:param float clip_value: restrict gradients to [-clip_value, clip_value]; clip_value should be a positive number.
:param str clip_type: two supported
modes::
1. 'norm': rescale gradients so that their total norm is at most clip_value;
2. 'value': clamp gradients to [-clip_value, clip_value];
gradients below -clip_value are set to -clip_value,
gradients above clip_value are set to clip_value.
:param None,torch.Tensor,List[torch.Tensor] parameters: usually obtained via model.parameters();
if None, gradients of all parameters in the Trainer's optimizers are clipped by default.
"""
super().__init__()
from torch import nn
if clip_type == 'norm':
self.clip_fun = nn.utils.clip_grad_norm_
elif clip_type == 'value':
self.clip_fun = nn.utils.clip_grad_value_
else:
raise ValueError("Only supports `norm` or `value` right now.")
if parameters is not None:
self.parameters = list(parameters)
else:
self.parameters = None
self.clip_value = clip_value
def on_after_trainer_initialized(self, trainer, driver):
assert 'torch' in driver.__class__.__name__.lower(), f"Callback:{self.__class__.__name__} only supports torch " \
f"related drivers for now."
parameters = []
for optimizer in trainer.driver.optimizers:
for param_group in optimizer.param_groups:
parameters.extend(param_group['params'])
self.parameters = parameters
assert len(self.parameters), "There are no parameters to be clipped."
def on_before_optimizers_step(self, trainer, optimizers):
for optimizer in trainer.driver.optimizers:
trainer.driver.grad_scaler.unscale_(optimizer)
self.clip_fun(self.parameters, self.clip_value)
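
For illustration (not part of this commit): the two `clip_type` modes delegate to standard torch utilities; a stand-alone sketch, pure torch and independent of fastNLP, of what each mode does:

import torch
from torch import nn

layer = nn.Linear(4, 4)
layer(torch.randn(2, 4)).sum().backward()  # populate .grad

# 'norm': rescale all gradients together so their global L2 norm is <= 1.0
nn.utils.clip_grad_norm_(layer.parameters(), max_norm=1.0)

# 'value': clamp every gradient element into [-1.0, 1.0] independently
nn.utils.clip_grad_value_(layer.parameters(), clip_value=1.0)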

View File

@@ -0,0 +1,58 @@
__all__ = [
'TorchWarmupCallback'
]
import math
from ..callback import Callback
class TorchWarmupCallback(Callback):
def __init__(self, warmup=0.1, schedule='constant'):
r"""
A callback that adjusts the learning rate; it only takes effect on steps where a parameter update actually happens.
:param int,float warmup: if warmup is an int, the learning rate follows the schedule up to that step; if warmup is a float
such as 0.1, the first 10% of the steps follow the schedule.
:param str schedule: how to adjust the learning rate.
linear: rise to the target learning rate (taken from the optimizer in the Trainer) during the warmup steps, then decay to 0 over the remaining steps;
constant: rise to the target learning rate during the warmup steps, then keep it constant afterwards.
"""
super().__init__()
self.warmup = max(warmup, 0.)
self.initial_lrs = [] # stores the learning rate of each param_group
if schedule == 'constant':
self.get_lr = self._get_constant_lr
elif schedule == 'linear':
self.get_lr = self._get_linear_lr
else:
raise RuntimeError("Only support 'linear', 'constant'.")
def _get_constant_lr(self, progress):
if progress < self.warmup:
return progress / self.warmup
return 1
def _get_linear_lr(self, progress):
if progress < self.warmup:
return progress / self.warmup
return max((progress - 1.) / (self.warmup - 1.), 0.)
def on_train_begin(self, trainer):
self.t_steps = trainer.total_batches
if self.warmup > 1:
self.warmup = self.warmup / self.t_steps
self.t_steps = max(2, self.t_steps) # must not be smaller than 2
# round t_steps up in case it is not divisible by accumulation_steps
self.t_steps = math.ceil(self.t_steps/trainer.accumulation_steps) * trainer.accumulation_steps
# record the initial learning rate of each param_group
for optimizer in trainer.driver.optimizers:
for group in optimizer.param_groups:
self.initial_lrs.append(group['lr'])
def on_before_optimizers_step(self, trainer, optimizers):
# accumulation_steps is added here to keep the lr from starting at 0
progress = (trainer.global_forward_batches + trainer.accumulation_steps) / self.t_steps
for optimizer in trainer.driver.optimizers:
for lr, group in zip(self.initial_lrs, optimizer.param_groups):
group['lr'] = lr * self.get_lr(progress)
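
For illustration (not part of this commit): the two warmup schedules are easy to check numerically; a minimal re-implementation of the multiplier formulas above, with no fastNLP imports:

def constant_multiplier(progress: float, warmup: float) -> float:
    # ramp linearly from 0 to 1 during warmup, then hold at 1
    return progress / warmup if progress < warmup else 1.0

def linear_multiplier(progress: float, warmup: float) -> float:
    # ramp up during warmup, then decay linearly to 0 at progress == 1
    if progress < warmup:
        return progress / warmup
    return max((progress - 1.0) / (warmup - 1.0), 0.0)

# with warmup=0.1: halfway through warmup the multiplier is 0.5, and the
# linear schedule is back down to 0.5 when 55% of the steps are done
assert constant_multiplier(0.05, 0.1) == 0.5
assert abs(linear_multiplier(0.55, 0.1) - 0.5) < 1e-9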

View File

@@ -219,10 +219,10 @@ class Trainer(TrainerEventTrigger):
""" 设置内部的 Evaluator """
if metrics is None and evaluate_dataloaders is not None:
raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.")
raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.")
if metrics is not None and evaluate_dataloaders is None:
raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.")
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.")
self.evaluator = None
self.monitor = monitor

View File

@@ -129,7 +129,7 @@ class Driver(ABC):
@property
def optimizers(self) -> List:
r"""
As shown below, the optimizers returned by the driver are always a List; if the user passes a single optimzer directly to the Trainer, we use a List to
As shown below, the optimizers returned by the driver are always a List; if the user passes a single optimizer directly to the Trainer, we use a List to
wrap it
:return: List[optimizer0, optimizer1, optimizer2, ...]

View File

@@ -1,5 +1,5 @@
import os
from typing import Optional, Union
from typing import Optional, Union, Callable, Dict, Tuple
from .jittor_driver import JittorDriver
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
@@ -61,14 +61,11 @@ class JittorMPIDriver(JittorDriver):
return self._data_device
return self.model_device
def train_step(self, batch):
return self._train_step(batch)
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
pass
def validate_step(self, batch):
return self._validate_step(batch)
def test_step(self, batch):
return self._test_step(batch)
def get_model_call_fn(self, fn: str) -> Tuple:
pass
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
reproducible: bool = False, sampler_or_batch_sampler=None):

View File

@@ -1,9 +1,11 @@
from typing import Dict, Union
from typing import Dict, Union, Tuple, Callable, Optional
from .jittor_driver import JittorDriver
from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.log import logger
if _NEED_IMPORT_JITTOR:
import jittor
@@ -27,42 +29,6 @@ class JittorSingleDriver(JittorDriver):
self.global_rank = 0
self.world_size = 1
if hasattr(self.model, "train_step"):
self._train_step = self.model.train_step
self._train_signature_fn = None
else:
self._train_step = self.model
model = self.unwrap_model()
self._train_signature_fn = model.execute
if hasattr(self.model, "evaluate_step"):
self._validate_step = self.model.evaluate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
self._validate_signature_fn = self.model.test_step
else:
self._validate_step = self.model
model = self.unwrap_model()
self._validate_signature_fn = model.execute
if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "evaluate_step"):
self._test_step = self.model.evaluate_step
self._test_signature_fn = self.model.evaluate_step
else:
self._test_step = self.model
model = self.unwrap_model()
self._test_signature_fn = model.execute
def train_step(self, batch) -> Dict:
if isinstance(batch, Dict):
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)
def step(self):
"""
The step function of jittor optimizers can take a loss argument
@@ -80,18 +46,24 @@ class JittorSingleDriver(JittorDriver):
for optimizer in self.optimizers:
optimizer.zero_grad()
def validate_step(self, batch):
if isinstance(batch, Dict):
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return self._validate_step(batch)
return fn(batch)
def test_step(self, batch):
if isinstance(batch, Dict):
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
def get_model_call_fn(self, fn: str) -> Tuple:
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
return fn, None
elif fn in {"train_step", "evaluate_step"}:
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
return self.model, self.model.forward
else:
return self._test_step(batch)
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
def unwrap_model(self):
return self.model
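
For illustration (not part of this commit): the `get_model_call_fn` contract used here, return a `(fn, signature_fn)` pair, preferring a model-defined step method and falling back to `forward` for the standard step names, can be mirrored independently of any driver; the toy model below is invented:

from typing import Callable, Optional, Tuple

class TinyModel:
    def forward(self, x):
        return {'pred': x}
    def evaluate_step(self, x):
        return {'pred': x, 'mode': 'eval'}

def resolve_call_fn(model, fn: str) -> Tuple[Callable, Optional[Callable]]:
    # prefer the named method; otherwise fall back to calling the model itself,
    # with forward as the signature used for parameter matching
    if hasattr(model, fn):
        method = getattr(model, fn)
        if not callable(method):
            raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
        return method, None
    elif fn in {'train_step', 'evaluate_step'}:
        return model, model.forward
    raise RuntimeError(f"There is no `{fn}` method in your {type(model)}.")

m = TinyModel()
assert resolve_call_fn(m, 'evaluate_step')[1] is None  # method exists on the model
assert resolve_call_fn(m, 'train_step')[0] is m        # falls back to forward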

View File

@@ -0,0 +1,376 @@
import io
import pickle
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
from typing import Any, List
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed as dist
if _TORCH_GREATER_EQUAL_1_8:
try:
from torch._C._distributed_c10d import ProcessGroupGloo
from torch._C._distributed_c10d import _ProcessGroupWrapper
except ImportError:
pass
from fastNLP.core.utils import apply_to_collection
def _validate_output_list_for_rank(my_rank, dst, gather_list):
if dst == my_rank:
if not gather_list:
raise ValueError(
"Argument ``gather_list`` must be specified on destination rank."
)
elif gather_list:
raise ValueError(
"Argument ``gather_list`` must NOT be specified "
"on non-destination ranks."
)
def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
"""
Gather objects from the other ranks onto the dst rank.
Gathers picklable objects from the whole group in a single process.
Similar to :func:`gather`, but Python objects can be passed in. Note that the
object must be picklable in order to be gathered.
Args:
obj (Any): Input object. Must be picklable.
object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
should be correctly sized as the size of the group for this
collective and will contain the output. Must be ``None`` on non-dst
ranks. (default is ``None``)
dst (int, optional): Destination rank. (default is 0)
group: (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.
Returns:
None. On the ``dst`` rank, ``object_gather_list`` will contain the
output of the collective.
.. note:: Note that this API differs slightly from the gather collective
since it does not provide an async_op handle and thus will be a blocking
call.
.. note:: Note that this API is not supported when using the NCCL backend.
.. warning::
:func:`gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.
Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.gather_object(
gather_objects[dist.get_rank()],
output if dist.get_rank() == 0 else None,
dst=0
)
>>> # On rank 0
>>> output
['foo', 12, {1: 2}]
"""
if group is None:
group = DEFAULT_TORCH_GROUP
if dist.distributed_c10d._rank_not_in_group(group):
return
# Ensure object_gather_list is specified appropriately.
my_rank = dist.get_rank()
_validate_output_list_for_rank(my_rank, dst, object_gather_list)
# Move tensors to cpu so that unpickling does not land them on the sending gpu.
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
input_tensor, local_size = _object_to_tensor(obj)
group_backend = dist.get_backend(group)
current_device = torch.device("cpu")
is_nccl_backend = group_backend == dist.Backend.NCCL
if is_nccl_backend:
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device)
object_size_list = [
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
# Allgather tensor sizes. An all-gather is needed here despite this being a
# gather, since each rank needs to broadcast a tensor of the same (maximal)
# size.
dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
# Avoid populating output tensors if the result won't be gathered on this rank.
if my_rank == dst:
coalesced_output_tensor = torch.empty(
max_object_size * group_size, dtype=torch.uint8, device=current_device
)
# Output tensors are nonoverlapping views of coalesced_output_tensor
output_tensors = [
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
for i in range(group_size)
]
# All ranks call gather with equal-sized tensors.
dist.gather(
input_tensor,
gather_list=output_tensors if my_rank == dst else None,
dst=dst,
group=group,
)
if my_rank != dst:
return
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8) # type: ignore[call-overload]
tensor_size = object_size_list[i]
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
def _object_to_tensor(obj, device=None):
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
# Otherwise, it will cause 100X slowdown.
# See: https://github.com/pytorch/pytorch/issues/65696
byte_tensor = torch.ByteTensor(byte_storage)
local_size = torch.LongTensor([byte_tensor.numel()])
if device is not None:
byte_tensor = byte_tensor.to(device)
local_size = local_size.to(device)
return byte_tensor, local_size
def _tensor_to_object(tensor, tensor_size):
buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size]
return _unpickler(io.BytesIO(buf)).load()
def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
# src rank send to all other ranks
size = torch.LongTensor([0]).to(device)
if cur_rank == src:
world_size = dist.get_world_size(group=group)
tensor, size = _object_to_tensor(obj)
tensor = tensor.to(device)
size = size.to(device)
# first synchronize the size information of obj;
dist.broadcast(size, src, group=group)
for subrank in range(world_size):
if subrank != src:
dist.send(tensor=tensor, dst=subrank, group=group, tag=tag)
else:
dist.broadcast(size, src, group=group)
tensor = torch.ByteTensor([0] * size).to(device)
dist.recv(tensor=tensor, src=src, group=group, tag=tag)
return _tensor_to_object(tensor.cpu(), size)
def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List:
"""
Provide all_gather for data of any type; non-tensor data is transmitted by pickle serialization and deserialization.
example:
obj = {
'a': [1, 1],
'b': [[1, 2], [1, 2]],
'c': {
'd': [1, 2]
}
}
->
[
{'a': 1, 'b':[1, 2], 'c':{'d': 1}},
{'a': 1, 'b':[1, 2], 'c':{'d': 2}}
]
:param obj: data of arbitrary structure; if it is a tensor, its shape must be identical on every GPU, and any non-tensor object is
serialized directly before transmission
:param device: this parameter currently has no effect
:param group:
:return: the result is [obj0, obj1, ...], where obj_i is the obj from rank i
"""
if group is None:
group = DEFAULT_TORCH_GROUP
if isinstance(obj, torch.Tensor):
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
dist.all_gather(objs, obj, group=group)
else:
objs = [None for _ in range(dist.get_world_size(group))]
# move tensors to cpu so that unpickling does not land them on the sending gpu
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
if _TORCH_GREATER_EQUAL_1_8:
dist.all_gather_object(objs, obj, group=group)
else:
objs = all_gather_object(objs, obj, group=group)
return objs
def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
"""
src 上的 obj 对象广播到其它 rank
:param obj:
:param src:
:param device:
:param group:
:return:
"""
if group is None:
group = DEFAULT_TORCH_GROUP
cur_rank = dist.get_rank(group)
if cur_rank == src:
# If there are tensors, move them all to cpu for pickling; otherwise unpickling may place them on the sender's device
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
if _TORCH_GREATER_EQUAL_1_8:
if cur_rank != src:
get_obj = [None]
dist.broadcast_object_list(get_obj, src=src, group=group)
return get_obj[0]
else:
dist.broadcast_object_list([obj], src=src, group=group)
return obj
if device is None:
device = torch.cuda.current_device()
if cur_rank == src:
tensor, size = _object_to_tensor(obj, device=device)
else:
size = torch.LongTensor([0]).to(device)
dist.broadcast(size, src=src, group=group)
if cur_rank != src:
tensor = torch.empty(
size.int().item(), # type: ignore[arg-type]
dtype=torch.uint8,
device=device
)
dist.broadcast(tensor, src=src, group=group)
return _tensor_to_object(tensor, tensor_size=size.item())
def _check_for_nccl_backend(group):
pg = group or dist.distributed_c10d._get_default_group()
# It is not expected for PG to be wrapped many times, but support it just
# in case
while isinstance(pg, _ProcessGroupWrapper):
pg = pg.wrapped_pg
return (
dist.is_nccl_available() and
isinstance(pg, dist.ProcessGroupNCCL)
)
def all_gather_object(object_list, obj, group=None):
"""
Copied from pytorch code so that it stays compatible with older pytorch versions.
Gathers picklable objects from the whole group into a list. Similar to
:func:`all_gather`, but Python objects can be passed in. Note that the object
must be picklable in order to be gathered.
Args:
object_list (list[Any]): Output list. It should be correctly sized as the
size of the group for this collective and will contain the output.
object (Any): Pickable Python object to be broadcast from current process.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.
Returns:
None. If the calling rank is part of this group, the output of the
collective will be populated into the input ``object_list``. If the
calling rank is not part of the group, the passed in ``object_list`` will
be unmodified.
.. note:: Note that this API differs slightly from the :func:`all_gather`
collective since it does not provide an ``async_op`` handle and thus
will be a blocking call.
.. note:: For NCCL-based processed groups, internal tensor representations
of objects must be moved to the GPU device before communication takes
place. In this case, the device used is given by
``torch.cuda.current_device()`` and it is the user's responsibility to
ensure that this is set so that each rank has an individual GPU, via
``torch.cuda.set_device()``.
.. warning::
:func:`all_gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.
Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]
"""
if dist.distributed_c10d._rank_not_in_group(group):
return
if _TORCH_GREATER_EQUAL_1_8:
current_device = torch.device("cpu")
is_nccl_backend = _check_for_nccl_backend(group)
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device("cuda", torch.cuda.current_device())
else:
current_device = torch.cuda.current_device()
input_tensor, local_size = _object_to_tensor(obj, device=current_device)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
object_sizes_tensor = torch.zeros(
group_size, dtype=torch.long, device=current_device
)
object_size_list = [
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
# Allgather tensor sizes
dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
coalesced_output_tensor = torch.empty(
max_object_size * group_size, dtype=torch.uint8, device=current_device
)
# Output tensors are nonoverlapping views of coalesced_output_tensor
output_tensors = [
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
for i in range(group_size)
]
dist.all_gather(output_tensors, input_tensor, group=group)
# Deserialize outputs back to object.
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
if tensor.device != torch.device("cpu"):
tensor = tensor.cpu()
tensor_size = object_size_list[i]
object_list[i] = _tensor_to_object(tensor, tensor_size)
return object_list
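
For illustration (not part of this commit): `_object_to_tensor` / `_tensor_to_object` are the core of every collective in this file: pickle the object into a byte tensor, ship it, unpickle it on the other side. A process-local round trip (no process group required) demonstrates the idea:

import io
import pickle
import torch

def object_to_tensor(obj):
    buf = io.BytesIO()
    pickle.Pickler(buf).dump(obj)
    byte_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(buf.getvalue()))
    return byte_tensor, torch.LongTensor([byte_tensor.numel()])

def tensor_to_object(tensor, size):
    raw = tensor.detach().cpu().numpy().tobytes()[:size]
    return pickle.Unpickler(io.BytesIO(raw)).load()

payload = {'a': [1, 2], 'b': 'text'}
t, n = object_to_tensor(payload)
assert tensor_to_object(t, int(n.item())) == payload  # lossless round trip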

View File

@@ -1,13 +1,12 @@
import os
import shutil
from functools import partial
from typing import List, Union, Optional, Dict
from typing import List, Union, Optional, Dict, Tuple, Callable
from .paddle_driver import PaddleDriver
from .fleet_launcher import FleetLauncher
from .utils import (
_FleetWrappingModel,
ForwardState,
_MODE_PARAMETER,
get_device_from_visible,
reset_seed,
replace_sampler,
@@ -47,8 +46,7 @@ if _NEED_IMPORT_PADDLE:
__all__ = [
"PaddleFleetDriver",
]
# if os.path.exists(self.gloo_rendezvous_dir):
# shutil.rmtree(self.gloo_rendezvous_dir)
class PaddleFleetDriver(PaddleDriver):
def __init__(
self,
@@ -104,34 +102,6 @@ class PaddleFleetDriver(PaddleDriver):
# we simply set model_device to None
self._model_device = None
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(step_fn, batch, signature_fn=signature_fn)
else:
return self._validate_step(batch)
model = model._layers
if hasattr(model, "train_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `train_step` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
if hasattr(model, "evaluate_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `evaluate_step` method, which we can not call actually, "
"we will call `forward` function instead of `evaluate_step` and you should note that.")
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
if hasattr(model, "test_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `test_step` method, which we can not call actually, we will"
" call `forward` function instead of `test_step` and you should note that.")
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
# when the parameter `device` is None and this parameter is not None, the corresponding data is moved to the specified device
self._data_device = kwargs.get("data_device", None)
if self._data_device is not None:
@@ -150,8 +120,6 @@ class PaddleFleetDriver(PaddleDriver):
self.world_size = None
self.global_rank = 0
self._configured = False # guards against duplicate calls to configure_ddp()
self._has_setup = False # guards against duplicate calls to setup()
self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {})
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)
@@ -173,6 +141,9 @@ class PaddleFleetDriver(PaddleDriver):
os.makedirs(name=self.output_from_new_proc, exist_ok=True)
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)
self._has_setup = False # this flag exists because the evaluator also runs setup, which is clearly unnecessary and should not happen;
self._has_fleetwrapped = False # whether the passed-in model has been wrapped with _FleetWrappingModel;
def setup(self):
"""
Spawn the other child processes from the main process, making the main process rank 0
@@ -268,17 +239,17 @@ class PaddleFleetDriver(PaddleDriver):
dist.barrier()
def configure_fleet(self):
if not self._configured and not isinstance(self.model, DataParallel):
if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
self.model = DataParallel(
_FleetWrappingModel(self.model),
**self._fleet_kwargs
)
self._has_fleetwrapped = True
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)
self._configured = True
def on_exception(self):
if os.path.exists(self.gloo_rendezvous_dir):
shutil.rmtree(self.gloo_rendezvous_dir)
super().on_exception()
@property
def world_size(self) -> int:
@@ -310,14 +281,39 @@ class PaddleFleetDriver(PaddleDriver):
return self._data_device
return self.model_device
def train_step(self, batch):
return self._train_step(batch)
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if self._has_fleetwrapped:
return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn,
wo_auto_param_call=self.wo_auto_param_call)
else:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return fn(batch)
def validate_step(self, batch):
return self._validate_step(batch)
def get_model_call_fn(self, fn: str) -> Tuple:
model = self.unwrap_model()
if self._has_fleetwrapped:
if hasattr(model, fn):
fn = getattr(model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
return fn, None
elif fn in {"train_step", "evaluate_step"}:
return model, model.forward
else:
raise RuntimeError(f"There is no `{fn}` method in your model.")
else:
if hasattr(model, fn):
logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements "
f"the `{fn}` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
elif fn not in {"train_step", "evaluate_step"}:
raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a "
"`DistributedDataParallel` model, which means that we will only call model.forward "
"function when we are in forward propagation.")
def test_step(self, batch):
return self._test_step(batch)
return self.model, model.forward
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
reproducible: bool = False, sampler_or_batch_sampler=None):
@@ -406,14 +402,6 @@ class PaddleFleetDriver(PaddleDriver):
else:
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
def backward(self, loss):
self.grad_scaler.scale(loss).backward()
def step(self):
for optimizer in self.optimizers:
self.grad_scaler.step(optimizer)
self.grad_scaler.update()
def is_global_zero(self):
return self.global_rank == 0
@@ -450,3 +438,45 @@ class PaddleFleetDriver(PaddleDriver):
if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
f"not {type(each_optimizer)}.")
def broadcast_object(self, obj, src:int=0, group=None, **kwargs):
"""
src 端将 obj 对象可能是 tensor 可能是 object 发送到 dst 如果是非 tensor 的对象会尝试使用 pickle 进行打包进行
传输然后再 dst 处再加载回来仅在分布式的 driver 中有实际意义
:param obj: obj可能是 Tensor 嵌套类型的数据
:param int src: source global rank
:param int dst: target global rank可以是多个目标 rank
:param group: 所属的 group
:param kwargs:
:return: 如果当前不是分布式 driver 直接返回输入的 obj 如果当前 rank 是接收端 global rank 包含在了 dst 则返回
接收到的参数如果是 source 端则返回发射的内容既不是发送端又不是接收端则返回 None
"""
return
return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)
def all_gather(self, obj, group) -> List:
"""
obj 互相传送到其它所有的 rank 其中 obj 可能是 Tensor也可能是嵌套结构的 object 如果不是基础类型的数据尝试通过
pickle 进行序列化接收到之后再反序列化
example:
obj = {
'a': [1, 1],
'b': [[1, 2], [1, 2]],
'c': {
'd': [1, 2]
}
}
->
[
{'a': 1, 'b':[1, 2], 'c':{'d': 1}},
{'a': 1, 'b':[1, 2], 'c':{'d': 2}}
]
:param obj: the object to transmit; it should have the same structure on every rank
:param group:
:return:
"""
return
return fastnlp_paddle_all_gather(obj, group=group)

View File

@@ -71,6 +71,14 @@ class PaddleDriver(Driver):
for optimizer in self.optimizers:
optimizer.clear_grad()
def backward(self, loss):
self.grad_scaler.scale(loss).backward()
def step(self):
for optimizer in self.optimizers:
self.grad_scaler.step(optimizer)
self.grad_scaler.update()
@staticmethod
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
r"""
@@ -115,28 +123,6 @@ class PaddleDriver(Driver):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
f"not {type(each_optimizer)}.")
def check_evaluator_mode(self, mode: str):
r"""
Because in the concrete drivers the logic for evaluate_step and test_step is "if the model does not implement this function, check whether it implements the other one",
if the user's evaluator has evaluate_fn 'validate' but the given model implements test_step rather than evaluate_step, then
we should warn the user about this behavior.
"""
model = self.unwrap_model()
if mode == "validate":
if not hasattr(model, "evaluate_step"):
if hasattr(model, "test_step"):
logger.warning(
"Your model does not have 'evaluate_step' method but has 'test_step' method, but you"
"are using 'Evaluator.validate', we are going to use 'test_step' to substitute for"
"'evaluate_step'.")
else:
if not hasattr(model, "test_step"):
if hasattr(model, "evaluate_step"):
logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
"are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for"
"'test_step'.")
@staticmethod
def tensor_to_numeric(tensor, reduce=None):
r"""
@@ -258,20 +244,21 @@ class PaddleDriver(Driver):
if hasattr(sampler, "state_dict") and callable(sampler.state_dict):
sampler_states = sampler.state_dict()
# If present, num_consumed_samples needs special handling: because the DataLoader prefetches, using the sampler's num_consumed_samples directly
# would overstate what has actually been consumed.
num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None)
# would overstate what has actually been consumed.
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
try:
sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size
except: # batch_size may be None; then we can only lose some precision
pass
assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report."
if isinstance(sampler, ReproducibleSampler):
# for a sampler, the actual number of samples must be computed
try:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
except: # batch_size may be None; then we can only lose some precision
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
states['sampler_states'] = sampler_states
else:
raise RuntimeError(
"The sampler has no `state_dict()` method, it will fail to recover to the specific batch.")
states["sampler_states"] = sampler_states
# 2. save the model state;
if should_save_model:
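
For illustration (not part of this commit), on why `num_consumed_samples_array` exists: a DataLoader with prefetching draws batches from the sampler ahead of what the training loop has actually finished, so the sampler's own counter overshoots. Recording the count after every batch and indexing by the number of batches really consumed gives the truthful figure. A toy illustration with invented numbers:

# batch_size=4; the DataLoader has prefetched 2 batches beyond the loop
num_consumed_samples_array = [0, 4, 8, 12, 16, 20]  # counter after each batch
sampler_counter = 20       # the sampler has handed out 5 batches (incl. prefetch)
num_consumed_batches = 3   # the training loop has only finished 3 of them

# checkpointing the raw counter would skip 8 unseen samples on resume;
# indexing by finished batches records the correct number instead
assert num_consumed_samples_array[num_consumed_batches] == 12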

View File

@@ -1,5 +1,5 @@
import os
from typing import Optional, Dict, Union
from typing import Optional, Dict, Union, Callable, Tuple
from .paddle_driver import PaddleDriver
from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible
@@ -11,16 +11,19 @@ from fastNLP.core.utils import (
get_paddle_device_id,
paddle_move_data_to_device,
)
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
RandomBatchSampler,
ReproducibleSampler,
RandomSampler,
re_instantiate_sampler,
)
from fastNLP.core.log import logger
if _NEED_IMPORT_PADDLE:
import paddle
from paddle import DataParallel
from paddle.fluid.reader import _DatasetKind
__all__ = [
@@ -28,109 +31,57 @@ __all__ = [
]
class PaddleSingleDriver(PaddleDriver):
def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs):
def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs):
if isinstance(model, DataParallel):
raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`")
cuda_visible_devices = os.environ.get(USER_CUDA_VISIBLE_DEVICES, None)
if cuda_visible_devices == "":
device = "cpu"
logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to"
"use `cpu` instead of `gpu` device.")
super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs)
if device is None:
raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.")
if device != "cpu":
if isinstance(device, int):
device_id = device
else:
device_id = get_paddle_device_id(device)
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
self.model_device = get_paddle_gpu_str(device)
self.local_rank = 0
self.global_rank = 0
self.world_size = 1
if isinstance(model, paddle.DataParallel):
# note that unwrap_model here calls the concrete subclass's method;
model = self.unwrap_model()
if hasattr(model, "train_step"):
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
"implements the `train_step` method, which we can not call actually, we will "
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = self.model
self._train_signature_fn = model.forward
if hasattr(model, "evaluate_step"):
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
"implements the `evaluate_step` method, which we can not call actually, we "
"will call `forward` function instead of `evaluate_step` and you should note that.")
self._validate_step = self.model
self._validate_signature_fn = model.forward
if hasattr(model, "test_step"):
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also "
"implements the `test_step` method, which we can not call actually, we will "
"call `forward` function instead of `test_step` and you should note that.")
self._test_step = self.model
self._test_signature_fn = model.forward
else:
if hasattr(self.model, "train_step"):
self._train_step = self.model.train_step
self._train_signature_fn = None
else:
self._train_step = self.model
# the input model is a `DataParallel`; we need to make sure its signature_fn is correct;
model = self.unwrap_model()
self._train_signature_fn = model.forward
if hasattr(self.model, "evaluate_step"):
self._validate_step = self.model.evaluate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
self._validate_signature_fn = self.model.test_step
else:
self._validate_step = self.model
model = self.unwrap_model()
self._validate_signature_fn = model.forward
if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "evaluate_step"):
self._test_step = self.model.evaluate_step
self._test_signature_fn = self.model.evaluate_step
else:
self._test_step = self.model
model = self.unwrap_model()
self._test_signature_fn = model.forward
def setup(self):
device = self.model_device
if device != "cpu":
device_id = get_paddle_device_id(device)
device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
device = get_device_from_visible(device, output_type=str)
device = get_device_from_visible(device, output_type=str)
paddle.device.set_device(device)
self.model.to(device)
def train_step(self, batch) -> Dict:
# If batch is a Dict, we do parameter matching for it by default; otherwise the batch is passed to `train_step` directly for the user to handle;
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return self._train_step(batch)
return fn(batch)
def backward(self, loss):
self.grad_scaler.scale(loss).backward()
def step(self):
for optimizer in self.optimizers:
self.grad_scaler.step(optimizer)
self.grad_scaler.update()
def validate_step(self, batch) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
def get_model_call_fn(self, fn: str) -> Tuple:
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
return fn, None
elif fn in {"train_step", "evaluate_step"}:
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
return self.model, self.model.forward
else:
return self._validate_step(batch)
def test_step(self, batch) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
def move_data_to_device(self, batch: 'paddle.Tensor'):
r"""
@@ -164,12 +115,18 @@ class PaddleSingleDriver(PaddleDriver):
return replace_sampler(dataloader, sampler)
if reproducible:
batch_sampler = RandomBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
if isinstance(args.sampler, paddle.io.RandomSampler):
# if the sampler was already random, just replace it
sampler = RandomSampler(args.sampler.data_source)
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
else:
batch_sampler = RandomBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
else:
return dataloader

View File

@@ -11,7 +11,6 @@ from typing import Dict, Optional, Union
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to
from fastNLP.core.samplers import RandomSampler
from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.log import logger
@@ -87,8 +86,6 @@ class ForwardState(IntEnum):
TEST = 2
PREDICT = 3
_MODE_PARAMETER = "forward_state"
class _FleetWrappingModel(Layer):
"""
Modeled on _DDPWrappingModel: paddle distributed training likewise needs the model wrapped with paddle.nn.DataParallel, following the same
@@ -98,83 +95,16 @@ class _FleetWrappingModel(Layer):
super(_FleetWrappingModel, self).__init__()
self.model = model
if isinstance(model, paddle.DataParallel):
model = model._layers
if hasattr(model, "train_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `train_step` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = self.model
self._train_signature_fn = model.forward
if hasattr(model, "evaluate_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `evaluate_step` method, which we can not call actually, "
"we will call `forward` function instead of `evaluate_step` and you should note that.")
self._validate_step = self.model
self._validate_signature_fn = model.forward
if hasattr(model, "test_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `test_step` method, which we can not call actually, we will"
" call `forward` function instead of `test_step` and you should note that.")
self._test_step = self.model
self._test_signature_fn = model.forward
else:
if hasattr(model, "train_step"):
self._train_step = model.train_step
self._train_signature_fn = None
else:
self._train_step = model
self._train_signature_fn = model.forward
if hasattr(model, "evaluate_step"):
self._validate_step = model.validate_step
self._validate_signature_fn = None
elif hasattr(model, "test_step"):
self._validate_step = model.test_step
self._validate_signature_fn = None
else:
self._validate_step = model
self._validate_signature_fn = model.forward
if hasattr(model, "test_step"):
self._test_step = model.test_step
self._test_signature_fn = None
elif hasattr(model, "evaluate_step"):
self._test_step = model.validate_step
self._test_signature_fn = None
else:
self._test_step = model
self._test_signature_fn = model.forward
def forward(self, batch, **kwargs) -> Dict:
forward_state = kwargs.pop(_MODE_PARAMETER)
fn = kwargs.pop("fastnlp_fn")
signature_fn = kwargs.pop("fastnlp_signature_fn")
wo_auto_param_call = kwargs.pop("wo_auto_param_call")
if forward_state == ForwardState.TRAIN:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)
elif forward_state == ForwardState.VALIDATE:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)
elif forward_state == ForwardState.TEST:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
elif forward_state == ForwardState.PREDICT:
raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.")
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
raise NotImplementedError("You should direct a concrete evaluate_fn.")
return fn(batch)
class DummyGradScaler:
"""

View File

@@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
# world_size and rank
if FASTNLP_BACKEND_LAUNCH in os.environ:
if device is not None:
logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
logger.warning_once("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
"up your script. And we will directly get the local device via "
"`os.environ['LOCAL_RANK']`.")
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)

View File

@@ -37,7 +37,12 @@ class TorchSingleDriver(TorchDriver):
super(TorchSingleDriver, self).__init__(model, fp16=fp16, **kwargs)
if device is None:
raise ValueError("Parameter `device` can not be None in `TorchSingleDriver`.")
logger.debug("device is not set, fastNLP will try to automatically get it.")
try:
device = next(model.parameters()).device
assert isinstance(device, torch.device)
except:
raise ValueError("fastNLP cannot get device automatically, please set device explicitly.")
self.model_device = device
@@ -70,6 +75,7 @@ class TorchSingleDriver(TorchDriver):
return self.model, model.forward
else:
# TODO calling a model method directly like this cannot trigger hooks; perhaps a warning is needed so that users with hooks know train_step will not fire them.
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):

View File

@@ -25,7 +25,7 @@ __all__ = [
from .utils import optimizer_state_to_device
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env
from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler
from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
@ -224,6 +224,11 @@ class TorchDriver(Driver):
optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy测试是不需要的
# 4. 保存fp16的状态
if not isinstance(self.grad_scaler, DummyGradScaler):
grad_scaler_state_dict = self.grad_scaler.state_dict()
states['grad_scaler_state_dict'] = grad_scaler_state_dict
logger.debug("Save optimizer state dict")
states["optimizers_state_dict"] = optimizers_state_dict
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
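# A hedged sketch of the fp16 state round-trip implemented above: GradScaler
# exposes state_dict()/load_state_dict(), and that dict is what gets stored
# under 'grad_scaler_state_dict' when fp16 training is enabled.
import torch

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
saved = scaler.state_dict()                # serialized with the checkpoint
resumed = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
resumed.load_state_dict(saved)             # restored when the checkpoint is loaded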
@ -232,7 +237,7 @@ class TorchDriver(Driver):
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
# 1. Load the optimizers' state;
optimizers_state_dict = states["optimizers_state_dict"]
optimizers_state_dict = states.pop("optimizers_state_dict")
for i in range(len(self.optimizers)):
optimizer: torch.optim.Optimizer = self.optimizers[i]
optimizer.load_state_dict(optimizers_state_dict[f"optimizer{i}"])
@ -244,26 +249,37 @@ class TorchDriver(Driver):
res = torch.load(folder.joinpath(FASTNLP_MODEL_FILENAME), map_location='cpu')
if only_state_dict:
model.load_state_dict(res)
logger.debug("Load model state dict.")
logger.debug("Load model state dict...")
else:
model.load_state_dict(res.state_dict())
logger.debug("Load model.")
logger.debug("Load model...")
# 3. Restore the sampler state;
# 3. Load the fp16 state
if 'grad_scaler_state_dict' in states:
grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
if not isinstance(self.grad_scaler, DummyGradScaler):
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
logger.debug("Load grad_scaler state dict...")
elif not isinstance(self.grad_scaler, DummyGradScaler):
logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
f"the training process may be unstable.")
# 4. Restore the sampler state;
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
sampler = dataloader_args.sampler
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states['sampler_states'])
sampler.load_state_dict(states.pop('sampler_states'))
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
# 4. Update trainer_state.batch_idx_in_epoch
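# A toy illustration (deliberately not fastNLP's RandomBatchSampler) of the
# state_dict round-trip that the resume logic above relies on: the sampler
# records how much of its order was consumed, so a reloaded run skips the
# batches that were already seen.
class ToyResumableSampler:
    def __init__(self, n):
        self.order = list(range(n))
        self.num_consumed = 0

    def __iter__(self):
        while self.num_consumed < len(self.order):
            self.num_consumed += 1
            yield self.order[self.num_consumed - 1]

    def state_dict(self):
        return {"order": list(self.order), "num_consumed": self.num_consumed}

    def load_state_dict(self, states):
        self.order = list(states["order"])
        self.num_consumed = states["num_consumed"]

s1 = ToyResumableSampler(6)
it = iter(s1)
_ = [next(it) for _ in range(2)]   # consume two samples, then "crash"
s2 = ToyResumableSampler(6)
s2.load_state_dict(s1.state_dict())
assert list(s2) == [2, 3, 4, 5]    # the resumed run continues where s1 stopped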

View File

@ -1,6 +1,7 @@
from typing import Optional, Dict, Union, Callable
from typing import Optional, Dict, Union, Callable, Tuple
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
from fastNLP.core.utils.utils import _get_fun_msg
if _NEED_IMPORT_PADDLE:
@ -48,33 +49,6 @@ class TorchPaddleDriver(Driver):
elif self._data_device is not None:
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
if hasattr(self.model, "train_step"):
self._train_step = self.model.train_step
self._train_signature_fn = None
else:
self._train_step = self.model
self._train_signature_fn = self.model.forward
if hasattr(self.model, "evaluate_step"):
self._validate_step = self.model.evaluate_step
self._validate_signature_fn = None
elif hasattr(self.model, "test_step"):
self._validate_step = self.model.test_step
self._validate_signature_fn = self.model.forward
else:
self._validate_step = self.model
self._validate_signature_fn = self.model.forward
if hasattr(self.model, "test_step"):
self._test_step = self.model.test_step
self._test_signature_fn = None
elif hasattr(self.model, "evaluate_step"):
self._test_step = self.model.evaluate_step
self._test_signature_fn = self.model.forward
else:
self._test_step = self.model
self._test_signature_fn = self.model.forward
def setup(self):
if self.model_device is not None:
paddle.device.set_device(self.model_device.replace("cuda", "gpu"))
@ -103,12 +77,6 @@ class TorchPaddleDriver(Driver):
f"'torch.optim.Optimizer' or 'paddle.optimizers.Optimizer' type, "
f"not {type(each_optimizer)}.")
def train_step(self, batch) -> Dict:
if isinstance(batch, Dict):
return auto_param_call(self._train_step, batch)
else:
return self._train_step(batch)
def step(self):
for optimizer in self.optimizers:
optimizer.step()
@ -125,17 +93,24 @@ class TorchPaddleDriver(Driver):
else:
raise ValueError("Unknown optimizers type.")
def validate_step(self, batch):
if isinstance(batch, Dict):
return auto_param_call(self._validate_step, batch)
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)
else:
return self._validate_step(batch)
return fn(batch)
def test_step(self, batch):
if isinstance(batch, Dict):
return auto_param_call(self._test_step, batch)
def get_model_call_fn(self, fn: str) -> Tuple:
if hasattr(self.model, fn):
fn = getattr(self.model, fn)
if not callable(fn):
raise RuntimeError(f"The `{fn}` attribute is not `Callable`.")
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
return fn, None
elif fn in {"train_step", "evaluate_step"}:
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
return self.model, self.model.forward
else:
return self._test_step(batch)
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
def predict_step(self, batch):
if isinstance(batch, Dict):

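# A minimal sketch of the dispatch get_model_call_fn implements above, assuming
# a plain nn.Module that defines neither train_step nor evaluate_step: the
# driver then falls back to calling the model itself, with forward supplying
# the signature for auto_param_call.
import torch
from torch import nn

model = nn.Linear(4, 2)
if hasattr(model, "train_step"):
    fn, signature_fn = model.train_step, None
else:
    fn, signature_fn = model, model.forward
out = fn(torch.randn(3, 4))  # equivalent to model(batch)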
View File

@ -1,9 +1,4 @@
__all__ = [
'BucketSampler',
'SortedSampler',
'ConstTokenNumSampler',
'ConstantTokenNumSampler',
'MixSampler',
'DopedSampler',
'MixSequentialSampler',
@ -26,7 +21,6 @@ __all__ = [
"re_instantiate_sampler"
]
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler

View File

@ -1,728 +0,0 @@
r"""
Sampler subclasses implement the various samplers fastNLP needs.
"""
__all__ = [
"BucketSampler",
"SortedSampler",
'ConstTokenNumSampler',
"ConstantTokenNumSampler",
]
from itertools import chain
from typing import List, Iterable
import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import Sampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Sampler
# class DopedSampler(Sampler):
# """
# A BatchSampler tailored for MixDataLoader: it mixes samples drawn from the given list of datasets into batches and yields them.
# """
#
# def __init__(self, dataset: Union[List, Dict], batch_size: int = None,
# sampler: Union[List[Sampler], Dict[str, Sampler]] = None,
# ds_ratio: Union[str, None, List[float], Dict[str, float]] = None, drop_last: bool = False) -> None:
# if batch_size <= 0:
# raise ValueError("batch_size should be a positive integer value, "
# "but got batch_size={}".format(batch_size))
# if not isinstance(drop_last, bool):
# raise ValueError("drop_last should be a boolean value, but got "
# "drop_last={}".format(drop_last))
# self.batch_size = batch_size
# self.drop_last = drop_last
# self.ds_ratio = ds_ratio
# if sampler is None:
# if isinstance(dataset, List):
# self.sampler = [SequentialSampler(ds) for ds in dataset]
# elif isinstance(dataset, Dict):
# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
#
# elif isinstance(sampler, List):
# if len(sampler) != len(dataset):
# raise ValueError("the length of sampler != the length of sampler")
# self.sampler = sampler
# else:
# self.sampler = sampler
# if ds_ratio == 'pad_to_most' or ds_ratio == 'truncate_to_least' or ds_ratio is None:
# self.ds_ratio = ds_ratio
# elif isinstance(ds_ratio, List):
# if not all(item >= 0 for item in ds_ratio):
# raise ValueError("batch_size should be a positive integer value, "
# "but got batch_size={}".format(ds_ratio))
# self.ds_ratio = ds_ratio
# else:
# raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None")
#
# def __iter__(self):
# samplers, index = [], 0
# if isinstance(self.sampler, List):
# for idx, sampler in enumerate(self.sampler):
# samplers.append((iter(sampler), self.batch_size, index, 0, idx))
# index += len(sampler)
# elif isinstance(self.sampler, Dict):
# for name, sampler in self.sampler.items():
# samplers.append((iter(sampler), self.batch_size, index, 0, name))
# index += len(sampler)
#
# def __len__(self):
# lens = 0
# max_len, ds_len = 0, 0
# if self.ds_ratio == 'truncate_to_least':
# if isinstance(self.sampler, List):
# max_len = min(len(sampler) for sampler in self.sampler)
# ds_len = len(self.sampler)
# elif isinstance(self.sampler, Dict):
# max_len = min(len(sampler) for _, sampler in self.sampler.items())
# for _, _ in self.sampler.items():
# ds_len += 1
#
# elif self.ds_ratio == 'pad_to_most':
# if isinstance(self.sampler, List):
# max_len = max(len(sampler) for sampler in self.sampler)
# ds_len = len(self.sampler)
# elif isinstance(self.sampler, Dict):
# max_len = max(len(sampler) for _, sampler in self.sampler.items())
# for _, _ in self.sampler.items():
# ds_len += 1
#
# if self.ds_ratio is None:
# if isinstance(self.sampler, List):
# for i in range(len(self.sampler)):
# sampler = self.sampler[i]
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# elif isinstance(self.sampler, Dict):
# for name, sampler in self.sampler.items():
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# elif self.ds_ratio == 'truncate_to_least' or self.ds_ratio == 'pad_to_most':
# for i in range(ds_len):
# if self.drop_last:
# lens += max_len // self.batch_size
# else:
# lens += (max_len + self.batch_size - 1) // self.batch_size
# return lens
#
# def demo(self):
# indexes = np.array([0]*self.batch_size + [1]*self.batch_size + [2]*self.batch_size)
# shift = np.array([0]*self.batch_size + [len(ds1)]*self.batch_size + [len(ds1)+len(ds2)]*self.batch_size)
# buffer = np.zeros(self.batch_size*self.num_ds, dtype=int)
# select_sampler = np.random.randint(0, self.batch_size*self.num_ds, num_sample=self.batch_size)
# select_indices = buffer[select_sampler] + shift[select_sampler]
# num_1 = (indexes[select_sampler]==0).sum()
#
# class MixSequentialSampler(Sampler):
# """
# A BatchSampler tailored for MixDataLoader: it samples the given list of datasets sequentially, yielding indices and only moving to the next dataset once the previous one is fully processed.
# """
#
# def __init__(self, dataset: Union[List, Dict], batch_size: int = None,
# sampler: Union[List[Sampler], Dict[str, Sampler], None] = None,
# drop_last: bool = False) -> None:
# """
#
# :param dataset: a list of data containers implementing __getitem__ and __len__
# :param batch_size: the batch size for each dataset; may be a list or an int (an int applies to all datasets)
# :param sampler: instantiated samplers, one per dataset
# :param drop_last: whether to drop the last batch when it is shorter than batch_size
# """
# # If dataset is a Dict, other arguments such as collate_fn must be a Dict or a Callable,
# if isinstance(dataset, Dict) and isinstance(sampler, List):
# raise ValueError(f"{sampler} must be dict")
#
# check that batch_size is a positive integer
# if batch_size <= 0:
# raise ValueError("batch_size should be a positive integer value, "
# "but got batch_size={}".format(batch_size))
#
# if not isinstance(drop_last, bool):
# raise ValueError("drop_last should be a boolean value, but got "
# "drop_last={}".format(drop_last))
# self.batch_size = batch_size
# self.drop_last = drop_last
# if sampler is None:
# if isinstance(dataset, List):
# self.sampler = [SequentialSampler(ds) for ds in dataset]
# elif isinstance(dataset, Dict):
# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
# elif isinstance(sampler, List):
# if len(sampler) != len(dataset):
# raise ValueError("the length of sampler != the length of sampler")
# self.sampler = sampler
#
# def __iter__(self) -> Iterable[List[int]]:
# """
# Sample in dataset order and yield packed batches.
# :return:
# """
# index = 0
# batch = []
# if isinstance(self.sampler, List):
# for i in range(len(self.sampler)):
# sampler = self.sampler[i]
# for idx in sampler:
# batch.append(idx + index)
# if len(batch) == self.batch_size:
# yield batch
# batch = []
# if len(batch) > 0 and not self.drop_last:
# yield batch
# batch = []
# index += len(sampler)
# elif isinstance(self.sampler, Dict):
# for name, sampler in self.sampler.items():
# for idx in sampler:
# batch.append(idx + index)
# if len(batch) == self.batch_size:
# yield batch
# batch = []
# if len(batch) > 0 and not self.drop_last:
# yield batch
# batch = []
# index += len(sampler)
#
# def __len__(self) -> int:
# lens = 0
# if isinstance(self.sampler, List):
# for i in range(len(self.sampler)):
# sampler = self.sampler[i]
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# elif isinstance(self.sampler, Dict):
# for _, sampler in self.sampler.items():
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# return lens
# class PollingSampler(Sampler):
# """
# A BatchSampler tailored for MixDataLoader: it samples the list of datasets in round-robin fashion, switching to the next dataset after taking one batch from the current one.
# """
#
# def __init__(self, dataset: Union[List, Dict], batch_size: int = 16,
# sampler: Union[List[Sampler], Dict[str, Sampler]] = None,
# drop_last: bool = False, ds_ratio="pad_to_most") -> None:
# """
#
# :param dataset: a list of data containers implementing __getitem__ and __len__
# :param batch_size: the batch size for each dataset; may be a list or an int (an int applies to all datasets)
# :param sampler: instantiated samplers, one per dataset
# :param drop_last: whether to drop the last batch when it is shorter than batch_size
# :param ds_ratio: when ds_ratio=None, the datasets are sampled in turn until all of them are exhausted; when
# ds_ratio='truncate_to_least', the shortest dataset sets the pace and longer datasets are truncated; when
# ds_ratio='pad_to_most', the longest dataset sets the pace and shorter datasets are resampled
# """
# # If dataset is a Dict, other arguments such as collate_fn must be a Dict or a Callable,
# if isinstance(dataset, Dict) and isinstance(sampler, List):
# raise ValueError(f"{sampler} must be dict")
# if isinstance(dataset, List) and isinstance(sampler, Dict):
# raise ValueError(f"{sampler} must be list")
# check that batch_size is a positive integer
# if batch_size <= 0:
# raise ValueError("batch_size should be a positive integer value, "
# "but got batch_size={}".format(batch_size))
#
# if not isinstance(drop_last, bool):
# raise ValueError("drop_last should be a boolean value, but got "
# "drop_last={}".format(drop_last))
#
# self.batch_size = batch_size
# self.drop_last = drop_last
# if sampler is None:
# if isinstance(dataset, List):
# self.sampler = [SequentialSampler(ds) for ds in dataset]
# elif isinstance(dataset, Dict):
# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()}
#
# elif isinstance(sampler, List):
# if len(sampler) != len(dataset):
# raise ValueError("the length of sampler != the length of sampler")
# self.sampler = sampler
# else:
# self.sampler = sampler
# if ds_ratio == 'pad_to_most' or ds_ratio == 'truncate_to_least' or ds_ratio is None:
# self.ds_ratio = ds_ratio
# else:
# raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None")
#
# def __iter__(self) -> Iterable[List[int]]:
# index is the base offset for a dataset's indices; pointer points to one dataset in the list
# index, pointer, samplers, flag = 0, 0, [], False
#
# if isinstance(self.sampler, List):
# for idx, sampler in enumerate(self.sampler):
# samplers.append((iter(sampler), self.batch_size, index, 0, idx))
# index += len(sampler)
# elif isinstance(self.sampler, Dict):
# for name, sampler in self.sampler.items():
# samplers.append((iter(sampler), self.batch_size, index, 0, name))
# index += len(sampler)
# if self.ds_ratio == 'pad_to_most':
# if isinstance(self.sampler, List):
# limit_len = max(len(ds) for ds in self.sampler)
# else:
# limit_len = max(len(ds) for _, ds in self.sampler.items())
# elif self.ds_ratio == 'truncate_to_least':
# if isinstance(self.sampler, List):
# limit_len = min(len(ds) for ds in self.sampler)
# else:
# limit_len = min(len(ds) for _, ds in self.sampler.items())
# else:
# limit_len = 0
# size of the last batch
# last_batch_size = limit_len % self.batch_size
#
# while True:
# all samplers exhausted, exit
# if len(samplers) == 0:
# break
# batch, flag = [], False
# sampler_len is the number of samples taken from this sampler so far
# sampler, batch_size, index, sampler_len, name = samplers.pop(0)
# for _ in range(batch_size):
# try:
# batch.append(index + next(sampler))
# sampler_len += 1
# except StopIteration:
# flag = True
# When ds_ratio is None (the first case), simply drop the exhausted sampler.
# if self.ds_ratio == 'pad_to_most' and sampler_len < limit_len:
# reset the sampler and take enough data to fill out the batch
# sampler = iter(self.sampler[name])
# since batch_size is guaranteed to be <= the dataset length, a full batch can always be taken
# for _ in range(batch_size-len(batch)):
# batch.append(next(sampler) + index)
# sampler_len += 1
# break
#
# # Cases where ds_ratio is not None.
# # Two situations trigger the logic below: 1. with truncate_to_least, when the shortest dataset's last
# # batch is smaller than batch_size, the longer datasets' last batch would come out longer; 2. with
# # pad_to_most, when the longest dataset's last batch is not equal to batch_size, the shorter datasets'
# # last batch would come out longer.
# if limit_len != 0 and limit_len < sampler_len:
# batch = batch[:last_batch_size]
# # For any ds_ratio, if the sampler still has data left, push it back to the end of the queue
# elif (limit_len == 0 and flag == False) or limit_len > sampler_len:
# samplers.append((sampler, batch_size, index, sampler_len, name))
# if len(batch) == batch_size:
# yield batch
# elif len(batch) > 0 and not self.drop_last:
# yield batch
#
# def __len__(self) -> int:
# lens = 0
# max_len, ds_len = 0, 0
# if self.ds_ratio == 'truncate_to_least':
# if isinstance(self.sampler, List):
# max_len = min(len(sampler) for sampler in self.sampler)
# ds_len = len(self.sampler)
# elif isinstance(self.sampler, Dict):
# max_len = min(len(sampler) for _, sampler in self.sampler.items())
# for _, _ in self.sampler.items():
# ds_len += 1
#
# elif self.ds_ratio == 'pad_to_most':
# if isinstance(self.sampler, List):
# max_len = max(len(sampler) for sampler in self.sampler)
# ds_len = len(self.sampler)
# elif isinstance(self.sampler, Dict):
# max_len = max(len(sampler) for _, sampler in self.sampler.items())
# for _, _ in self.sampler.items():
# ds_len += 1
# if self.ds_ratio is None:
# if isinstance(self.sampler, List):
# for i in range(len(self.sampler)):
# sampler = self.sampler[i]
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# elif isinstance(self.sampler, Dict):
# for name, sampler in self.sampler.items():
# if self.drop_last:
# lens += len(sampler) // self.batch_size
# else:
# lens += (len(sampler) + self.batch_size - 1) // self.batch_size
# else:
# for i in range(ds_len):
# if self.drop_last:
# lens += max_len // self.batch_size
# else:
# lens += (max_len + self.batch_size - 1) // self.batch_size
# return lens
class BucketSampler(Sampler):
r"""
A bucketed `Random Sampler`: randomly yields elements of similar length together.
"""
def __init__(self, dataset, num_buckets=10, batch_size=None, seq_len_field_name='seq_len', drop_last=False) -> None:
r"""
:param int num_buckets: the number of buckets
:param int batch_size: the batch size. Defaults to None; Trainer/Tester will set this value correctly when they
invoke BucketSampler. When used outside Trainer/Tester, the value must be passed explicitly.
:param str seq_len_field_name: the name of the `field` holding the sequence lengths
"""
self.dataset = dataset
self.num_buckets = num_buckets
self.batch_size = batch_size
self.seq_len_field_name = seq_len_field_name
def set_batch_size(self, batch_size) -> None:
r"""
:param int batch_size: the size of each batch
:return:
"""
self.batch_size = batch_size
def __iter__(self):
if self.batch_size is None:
raise RuntimeError("batch_size is None.")
seq_lens = self.dataset.get_all_fields()[self.seq_len_field_name].content
total_sample_num = len(seq_lens)
bucket_indexes = []
assert total_sample_num >= self.num_buckets, "The number of samples is smaller than the number of buckets."
num_sample_per_bucket = total_sample_num // self.num_buckets
for i in range(self.num_buckets):
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])
bucket_indexes[-1][1] = total_sample_num
sorted_seq_lens = list(sorted([(idx, seq_len) for
idx, seq_len in zip(range(total_sample_num), seq_lens)],
key=lambda x: x[1]))
batchs = []
left_init_indexes = []
for b_idx in range(self.num_buckets):
start_idx = bucket_indexes[b_idx][0]
end_idx = bucket_indexes[b_idx][1]
sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
num_batch_per_bucket = len(left_init_indexes) // self.batch_size
np.random.shuffle(left_init_indexes)
for i in range(num_batch_per_bucket):
batchs.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size])
left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:]
if len(left_init_indexes) != 0:
batchs.append(left_init_indexes)
np.random.shuffle(batchs)
return chain(*batchs)
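# A hedged usage sketch of the BucketSampler removed above, using a stand-in
# object that exposes the one interface the sampler reads,
# get_all_fields()[...].content; the numbers are illustrative only.
import numpy as np

class _ToyDataset:
    def __init__(self, seq_lens):
        self._field = type("Field", (), {"content": list(seq_lens)})()

    def get_all_fields(self):
        return {"seq_len": self._field}

ds = _ToyDataset(np.random.randint(1, 50, size=100))
sampler = BucketSampler(ds, num_buckets=5, batch_size=8)
order = list(sampler)                     # indices grouped so each batch has similar lengths
assert sorted(order) == list(range(100))  # every sample is scheduled exactly once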
class ConstTokenNumSampler(Sampler):
"""
Tries to keep the number of input tokens in each batch roughly the same.
"""
def __init__(self, dataset, seq_len_field_name: str, max_token: int = 4096, max_sentence: int = -1,
need_be_multiple_of: int = 1, num_bucket: int = -1) -> None:
"""
:param dataset:
:param str seq_len_field_name: the field indicating each sample's length
:param int max_token: the maximum number of tokens per batch
:param int max_sentence: the maximum number of instances per batch; -1 means it is determined by max_token
:param int need_be_multiple_of: the number of instances in each generated batch must be a multiple of this value (useful under DataParallel)
:param int num_bucket: split the data by length into num_bucket buckets; samples in a batch are combined within a bucket whenever possible, which reduces padding
"""
assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1
self.dataset = dataset
self.seq_len_field_name = seq_len_field_name
self.num_bucket = num_bucket
self.max_token = max_token
self._max_sentence = max_sentence
self.need_be_multiple_of = need_be_multiple_of
assert len(self.dataset) > self.num_bucket, "The number of samples should be larger than buckets."
seq_len = self.dataset.get_field(self.seq_len_field_name)
self.seq_len = seq_len
seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
seq_len_indice.sort(key=lambda x: x[0])
indice_in_buckets = []
if self.num_bucket > 0:
sample_per_bucket = len(seq_len_indice) // self.num_bucket
i = 0
while len(indice_in_buckets) < len(seq_len_indice):
indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket])
i += 1
else:
indice_in_buckets = [seq_len_indice]
self.indice_in_buckets = indice_in_buckets
self.get_new_order()
@property
def max_sentence(self):
if self._max_sentence < 1:
return 100000000
return self._max_sentence
@max_sentence.setter
def max_sentence(self, max_sentence):
self._max_sentence = max_sentence
def get_new_order(self) -> None:
np.random.shuffle(self.indice_in_buckets)
for bucket in self.indice_in_buckets:
np.random.shuffle(bucket)
indices = list(chain(*self.indice_in_buckets))
batches = []
cur_max_len = 0
batch = []
for length, i in indices:
max_len = max(length, cur_max_len)
if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence:
left_sample = len(batch) % self.need_be_multiple_of
add_samples = batch.copy()
cur_max_len = length
if left_sample != 0:
add_samples = add_samples[:-left_sample]
batch = batch[-left_sample:]
cur_max_len = max(cur_max_len, max(self.seq_len[i] for i in batch))
else:
batch = []
if len(add_samples) == 0:
raise RuntimeError(
f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
batches.append(add_samples)
else:
cur_max_len = max_len
batch.append(i)
if batch:
left_sample = len(batch) % self.need_be_multiple_of
add_samples = batch.copy()
if left_sample != 0:
add_samples = add_samples[:-left_sample].copy()
if add_samples:
batches.append(add_samples)
np.random.shuffle(batches)
self.batches = batches
def __iter__(self) -> Iterable[int]:
for batch in self.batches:
yield batch
self.get_new_order()
def __len__(self):
return len(self.batches)
class ConstantTokenNumSampler:
"""
Tries to keep the number of input tokens in each batch roughly the same.
"""
def __init__(self, seq_len, max_token: int = 4096, max_sentence: int = -1,
need_be_multiple_of: int = 1, num_bucket: int = -1) -> None:
"""
:param List[int] seq_len: the length of each sample; typically obtained via dataset.get_field('seq_len').content
:param int max_token: the maximum number of tokens per batch
:param int max_sentence: the maximum number of instances per batch; -1 means it is determined by max_token
:param int need_be_multiple_of: the number of instances in each generated batch must be a multiple of this value (useful under DataParallel)
:param int num_bucket: split the data by length into num_bucket buckets; samples in a batch are combined within a bucket whenever possible, which reduces padding
"""
assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1
assert len(seq_len) > num_bucket, "The number of samples should be larger than buckets."
self.seq_len = seq_len
self.max_token = max_token
self._max_sentence = max_sentence
self.need_be_multiple_of = need_be_multiple_of
seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
seq_len_indice.sort(key=lambda x: x[0])
indice_in_buckets = []
if num_bucket > 0:
sample_per_bucket = len(seq_len_indice) // num_bucket
i = 0
while len(indice_in_buckets) < len(seq_len_indice):
indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket])
i += 1
else:
indice_in_buckets = [seq_len_indice]
self.indice_in_buckets = indice_in_buckets
self.get_new_order()
@property
def max_sentence(self):
if self._max_sentence < 1:
return 100000000
return self._max_sentence
@max_sentence.setter
def max_sentence(self, max_sentence):
self._max_sentence = max_sentence
def get_new_order(self) -> None:
np.random.shuffle(self.indice_in_buckets)
for bucket in self.indice_in_buckets:
np.random.shuffle(bucket)
indices = list(chain(*self.indice_in_buckets))
batches = []
cur_max_len = 0
batch = []
for length, i in indices:
max_len = max(length, cur_max_len)
if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence:
left_sample = len(batch) % self.need_be_multiple_of
add_samples = batch.copy()
cur_max_len = length
if left_sample != 0:
add_samples = add_samples[:-left_sample]
batch = batch[-left_sample:]
cur_max_len = max(cur_max_len, max(self.seq_len[i] for i in batch))
else:
batch = []
if len(add_samples) == 0:
raise RuntimeError(
f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
batches.append(add_samples)
else:
cur_max_len = max_len
batch.append(i)
if batch:
left_sample = len(batch) % self.need_be_multiple_of
add_samples = batch.copy()
if left_sample != 0:
add_samples = add_samples[:-left_sample].copy()
if add_samples:
batches.append(add_samples)
np.random.shuffle(batches)
self.batches = batches
def __iter__(self) -> Iterable[int]:
for batch in self.batches:
yield batch
self.get_new_order()
def __len__(self):
return len(self.batches)
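# A hedged usage sketch of ConstantTokenNumSampler: every yielded batch is a
# list of sample indices whose total token count stays within max_token; the
# lengths below are illustrative.
import numpy as np

seq_len = np.random.randint(5, 60, size=200).tolist()
sampler = ConstantTokenNumSampler(seq_len, max_token=512, num_bucket=8)
for batch in sampler:
    assert sum(seq_len[i] for i in batch) <= 512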
class SortedSampler(Sampler):
r"""
Sorts samples by length; mainly used during evaluation, where it speeds things up by reducing padding.
"""
def __init__(self, dataset, seq_len_field_name: str = 'seq_len', descending: bool = True) -> None:
"""
:param str seq_len_field_name: the field to sort by. If the field contains numbers, sort by those numbers directly;
otherwise sort by the length of the field's content.
:param bool descending: whether to sort in descending order
"""
self.dataset = dataset
self.seq_len_field_name = seq_len_field_name
self.descending = descending
def __iter__(self) -> Iterable[int]:
seq_lens = self.dataset.get_field(self.seq_len_field_name).content
try:
seq_lens = list(map(len, seq_lens))
except:
pass
orders = np.argsort(seq_lens).tolist()  # ascending order
if self.descending:
orders = orders[::-1]
for order in orders:
yield order
def simple_sort_bucketing(lengths):
r"""
:param lengths: list of int, the lengths of all examples.
:return data: 2-level list
::
[
[index_11, index_12, ...], # bucket 1
[index_21, index_22, ...], # bucket 2
...
]
"""
lengths_mapping = [(idx, length) for idx, length in enumerate(lengths)]
sorted_lengths = sorted(lengths_mapping, key=lambda x: x[1])
# TODO: need to return buckets
return [idx for idx, _ in sorted_lengths]
def k_means_1d(x, k, max_iter=100):
r"""Perform k-means on 1-D data.
:param x: list of int, representing points in 1-D.
:param k: the number of clusters required.
:param max_iter: maximum iteration
:return centroids: numpy array, centroids of the k clusters
assignment: numpy array, 1-D, the bucket id assigned to each example.
"""
sorted_x = sorted(list(set(x)))
x = np.array(x)
if len(sorted_x) < k:
raise ValueError("too few buckets")
gap = len(sorted_x) / k
centroids = np.array([sorted_x[int(x * gap)] for x in range(k)])
assign = None
for i in range(max_iter):
# Cluster Assignment step
assign = np.array([np.argmin([np.absolute(x_i - c) for c in centroids]) for x_i in x])
# Move centroids step
new_centroids = np.array([x[assign == j].mean() for j in range(k)])
if (new_centroids == centroids).all():
centroids = new_centroids
break
centroids = new_centroids
return np.array(centroids), assign
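# An illustrative call to the 1-D k-means above: clearly separated lengths
# cluster into three buckets.
lengths = [3, 4, 5, 20, 21, 22, 49, 50, 52]
centroids, assign = k_means_1d(lengths, k=3)
print(centroids)  # roughly [4, 21, 50]
print(assign)     # e.g. [0 0 0 1 1 1 2 2 2]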
def k_means_bucketing(lengths, buckets):
r"""Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths.
:param lengths: list of int, the length of all samples.
:param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length
threshold for each bucket (This is usually None.).
:return data: 2-level list
::
[
[index_11, index_12, ...], # bucket 1
[index_21, index_22, ...], # bucket 2
...
]
"""
bucket_data = [[] for _ in buckets]
num_buckets = len(buckets)
_, assignments = k_means_1d(lengths, num_buckets)
for idx, bucket_id in enumerate(assignments):
if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
bucket_data[bucket_id].append(idx)
return bucket_data

View File

@ -203,7 +203,7 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
:return:
"""
if fn_name is not None:
assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}."
assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`."
parameters = list(inspect.signature(fn).parameters.values())
if inspect.ismethod(fn):
@ -606,16 +606,38 @@ def seq_len_to_mask(seq_len, max_len=None):
return mask
def wait_to_success(fn, no=False):
def wait_filepath(path, exist=True):
"""
Wait until the existence state of path equals {exist}, then return.
:param path: the path to watch
:param exist: if True, return once the path exists; if False, return once the path no longer exists
:return:
"""
if isinstance(path, str):
path = Path(path)
assert isinstance(path, Path)
count = 0
while True:
sleep(0.01)
if (no and not fn()) or (not no and fn()):
if path.exists() == exist:
break
count += 1
if count % 1000 == 0:
msg = 'create' if exist else 'delete'
logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...")
# This guards against races on distributed file systems: rank 0 issues the delete and moves on, but the actual
# deletion is carried out on the remote file system, so rank 1 may still observe the file while the operation is in flight.
def synchronize_safe_rm(path: Optional[Union[str, Path]]):
"""
This guards against races on distributed file systems: rank 0 issues the delete and moves on, but the actual
deletion is carried out on the remote file system, so rank 1 may still observe the file while the operation is in flight.
This function guarantees that every process has observed the deletion of path before returning; make sure path is
identical on every process, otherwise the processes will deadlock.
:param path:
:return:
"""
if path is None:
return
if isinstance(path, str):
@ -624,7 +646,7 @@ def synchronize_safe_rm(path: Optional[Union[str, Path]]):
return
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
_recursive_rm(path)
wait_to_success(path.exists, no=True)
wait_filepath(path, exist=False)
def _recursive_rm(path: Path):
@ -643,6 +665,8 @@ def _recursive_rm(path: Path):
def synchronize_mkdir(path: Optional[Union[str, Path]]):
"""
Note that this function creates a directory; do not use it if you need to create a file.
It guarantees that every process has observed the creation of path before returning; make sure path is identical
on every process, otherwise the processes will deadlock.
"""
if path is None:
return
@ -652,7 +676,7 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
path.mkdir(parents=True, exist_ok=True)
wait_to_success(path.exists)
wait_filepath(path, exist=True)
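# A hedged sketch of the rank-0-writes, everyone-polls pattern shared by
# synchronize_safe_rm and synchronize_mkdir: only global rank 0 mutates the
# path, and every process waits until the change is visible to it (assuming
# the rank environment variable is literally named FASTNLP_GLOBAL_RANK).
import os
from pathlib import Path

if int(os.environ.get("FASTNLP_GLOBAL_RANK", 0)) == 0:
    Path("ckpts").mkdir(parents=True, exist_ok=True)
wait_filepath("ckpts", exist=True)  # all ranks proceed with a consistent view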
def get_class_that_defined_method(method):

View File

@ -5,7 +5,6 @@
import os
import json
import sys
import subprocess
from collections import defaultdict

View File

@ -50,8 +50,6 @@ class ConllLoader(Loader):
The fields of the DataSet returned by ConllLoader are determined by the headers passed in.
Lines starting with "-DOCSTART-" are ignored, since that token is used as a document separator in CoNLL 2003.
"""
def __init__(self, headers, sep=None, indexes=None, dropna=True):
@ -93,6 +91,7 @@ class ConllLoader(Loader):
class Conll2003Loader(ConllLoader):
r"""
Reads data for the CoNLL-2003 task. The content should look like the following: the first column is raw_words, the second pos, the third chunking, and the fourth ner.
Lines starting with "-DOCSTART-" are ignored, since that token is used as a document separator in CoNLL 2003.
Example::

View File

@ -85,7 +85,7 @@ class MixModule:
def test_step(self, batch):
raise NotImplementedError
def validate_step(self, batch):
def evaluate_step(self, batch):
raise NotImplementedError
def train(self):

View File

@ -0,0 +1,41 @@
import pytest
import numpy as np
from fastNLP.core.callbacks import TorchGradClipCallback, Callback
from fastNLP import Trainer
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from tests.helpers.callbacks.prepare_trainer_args_for_torch_test import get_trainer_args
class CheckClipCallback(Callback):
def __init__(self, parameters, clip_type, clip_value):
self.parameters = parameters
self.clip_type = clip_type
self.clip_value = clip_value
def on_after_optimizers_step(self, trainer, optimizers):
for param in self.parameters:
if self.clip_type == 'value':
assert param.grad.max().item()<=self.clip_value
else:
assert np.linalg.norm(param.grad.cpu().view(-1).numpy())<=self.clip_value
@pytest.mark.parametrize('accumulation_steps', [1, 3, 5])
@pytest.mark.parametrize('fp16', [True, False])
@pytest.mark.parametrize('clip_type', ['norm', 'value'])
@pytest.mark.parametrize('clip_value', [1, 2])
def test_torch_grad_clip_callback(accumulation_steps, fp16, clip_type, clip_value):
if not torch.cuda.is_available() and fp16:
pytest.skip("No cuda, cannot test fp16.")
device = 'cuda' if fp16 else 'cpu'
kwargs = get_trainer_args(lr=1, device=device)
callbacks = []
callbacks.append(TorchGradClipCallback(clip_value=clip_value, clip_type=clip_type))
callbacks.append(CheckClipCallback(kwargs['model'].parameters(), clip_type, clip_value))
trainer = Trainer(**kwargs, callbacks=callbacks, fp16=fp16)
trainer.run()
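# A hedged sketch of the two clipping modes the test above exercises, applied
# directly with PyTorch's own utilities; presumably TorchGradClipCallback wraps
# the same calls after each optimizer step.
import torch
from torch import nn

model = nn.Linear(4, 4)
loss = model(torch.randn(8, 4)).pow(2).sum()
loss.backward()
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)  # clip_type='value'
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)     # clip_type='norm'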

View File

@ -0,0 +1,34 @@
import pytest
import numpy as np
from fastNLP.core.callbacks import TorchWarmupCallback, Callback
from fastNLP import Trainer
from tests.helpers.callbacks.prepare_trainer_args_for_torch_test import get_trainer_args
class RecordLrCallback(Callback):
def __init__(self):
self.lrs = []
def on_after_optimizers_step(self, trainer, optimizers):
self.lrs.append(trainer.driver.optimizers[0].param_groups[0]['lr'])
@pytest.mark.parametrize('warmup', [5, 0.1])
@pytest.mark.parametrize('schedule', ['constant', 'linear'])
@pytest.mark.parametrize('accumulation_steps', [1, 3, 4])
def test_torch_warmup_callback(warmup, schedule, accumulation_steps):
kwargs = get_trainer_args(lr=0.1, bsz=4)
callback = TorchWarmupCallback(warmup, schedule)
r_callback = RecordLrCallback()
kwargs['callbacks'] = [callback, r_callback]
trainer = Trainer(**kwargs, accumulation_steps=accumulation_steps)
trainer.run()
if schedule == 'linear':
assert kwargs['optimizers'].param_groups[0]['lr'] <= 0.01
elif schedule == 'constant':
assert np.allclose(0.1, kwargs['optimizers'].param_groups[0]['lr'])
assert len(r_callback.lrs)<=trainer.total_batches//accumulation_steps+1
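# A plausible sketch (not the callback's exact code) of the two schedules this
# test checks: the lr ramps up for `warmup` steps, then either stays flat
# ('constant') or decays linearly toward zero ('linear').
def warmup_lr(base_lr, step, warmup, total, schedule="linear"):
    if step < warmup:
        return base_lr * step / max(1, warmup)
    if schedule == "constant":
        return base_lr
    return base_lr * max(0.0, (total - step) / max(1.0, total - warmup))

print(warmup_lr(0.1, 2, 5, 100))    # still warming up: 0.04
print(warmup_lr(0.1, 100, 5, 100))  # fully decayed under 'linear': 0.0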

View File

@ -1,13 +1,11 @@
import pytest
import os
os.environ["FASTNLP_BACKEND"] = "paddle"
from typing import Any
from dataclasses import dataclass
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
from paddle.optimizer import Adam
from paddle.io import DataLoader
@ -19,40 +17,18 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordM
from tests.helpers.utils import magic_argv_env_context
@dataclass
class MNISTTrainPaddleConfig:
class TrainPaddleConfig:
num_labels: int = 10
feature_dimension: int = 784
feature_dimension: int = 10
batch_size: int = 32
batch_size: int = 2
shuffle: bool = True
validate_every = -5
evaluate_every = 2
driver: str = "paddle"
device = "gpu"
@dataclass
class MNISTTrainFleetConfig:
num_labels: int = 10
feature_dimension: int = 784
batch_size: int = 32
shuffle: bool = True
validate_every = -5
@dataclass
class TrainerParameters:
model: Any = None
optimizers: Any = None
train_dataloader: Any = None
validate_dataloaders: Any = None
input_mapping: Any = None
output_mapping: Any = None
metrics: Any = None
@pytest.mark.parametrize("driver,device", [("paddle", "cpu")("paddle", 1)])
@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True),
RichCallback(5), RecordLossCallback(loss_threshold=0.3)]])
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
RichCallback(5)]])
@magic_argv_env_context
def test_trainer_paddle(
driver,
@ -60,38 +36,36 @@ def test_trainer_paddle(
callbacks,
n_epochs=2,
):
trainer_params = TrainerParameters()
trainer_params.model = PaddleNormalModel_Classification_1(
num_labels=MNISTTrainPaddleConfig.num_labels,
feature_dimension=MNISTTrainPaddleConfig.feature_dimension
model = PaddleNormalModel_Classification_1(
num_labels=TrainPaddleConfig.num_labels,
feature_dimension=TrainPaddleConfig.feature_dimension
)
trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001)
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
train_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(6400, 10),
batch_size=MNISTTrainPaddleConfig.batch_size,
dataset=PaddleRandomMaxDataset(20, 10),
batch_size=TrainPaddleConfig.batch_size,
shuffle=True
)
val_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(1000, 10),
batch_size=MNISTTrainPaddleConfig.batch_size,
dataset=PaddleRandomMaxDataset(20, 10),
batch_size=TrainPaddleConfig.batch_size,
shuffle=True
)
trainer_params.train_dataloader = train_dataloader
trainer_params.validate_dataloaders = val_dataloader
trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every
trainer_params.metrics = {"acc": Accuracy(backend="paddle")}
train_dataloader = train_dataloader
evaluate_dataloaders = val_dataloader
evaluate_every = TrainPaddleConfig.evaluate_every
metrics = {"acc": Accuracy(backend="paddle")}
trainer = Trainer(
model=trainer_params.model,
model=model,
driver=driver,
device=device,
optimizers=trainer_params.optimizers,
train_dataloader=trainer_params.train_dataloader,
validate_dataloaders=trainer_params.validate_dataloaders,
validate_every=trainer_params.validate_every,
input_mapping=trainer_params.input_mapping,
output_mapping=trainer_params.output_mapping,
metrics=trainer_params.metrics,
optimizers=optimizers,
train_dataloader=train_dataloader,
evaluate_dataloaders=evaluate_dataloaders,
evaluate_every=evaluate_every,
input_mapping=None,
output_mapping=None,
metrics=metrics,
n_epochs=n_epochs,
callbacks=callbacks,

View File

@ -117,12 +117,13 @@ class TestSetDistReproDataloader:
"""
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when dist is a BucketedBatchSampler.
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
assert not (replaced_loader is dataloader)
@ -133,12 +134,13 @@ class TestSetDistReproDataloader:
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when dist is a RandomSampler.
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
sampler = RandomSampler(self.dataset, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
sampler = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)
assert not (replaced_loader is dataloader)
@ -171,14 +173,15 @@ class TestSetDistReproDataloader:
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader
has a BucketedBatchSampler.
"""
dataloader = DataLoader(
self.dataset,
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4),
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle),
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
@ -195,12 +198,13 @@ class TestSetDistReproDataloader:
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader has a RandomSampler.
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
@ -222,11 +226,12 @@ class TestSetDistReproDataloader:
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader is an ordinary one.
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
assert replaced_loader is dataloader
@ -238,14 +243,15 @@ class TestSetDistReproDataloader:
"""
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler is a
ReproducibleBatchSampler.
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
@ -258,13 +264,14 @@ class TestSetDistReproDataloader:
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler.sampler is a
ReproducibleSampler.
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@ -276,16 +283,17 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when dist is 'dist' and the dataloader is an ordinary one.
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
assert not (replaced_loader is dataloader)
@ -293,7 +301,7 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
dist.barrier()
"""
@ -302,13 +310,14 @@ class TestSetDistReproDataloader:
"""
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler
is a ReproducibleSampler.
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@ -320,18 +329,19 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler
is an UnrepeatedSampler.
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True)
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@ -349,11 +359,12 @@ class TestSetDistReproDataloader:
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when dist is 'unrepeatdist' and the dataloader is an ordinary one.
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
assert not (replaced_loader is dataloader)

View File

@ -1,4 +1,5 @@
import os
os.environ["FASTNLP_BACKEND"] = "paddle"
import pytest
from pathlib import Path
@ -56,34 +57,57 @@ def test_save_and_load_with_randombatchsampler(only_state_dict):
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
)
num_consumed_batches = 2
# TODO iterate a few more batches here once checkpoint-resume training is complete
already_seen_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_set.update(batch)
sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
# Load
# change the batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
# 1. Check the optimizer state
# TODO the optimizer state_dict is always empty
# 2. Check that the batch_sampler is loaded and substituted correctly
replaced_loader = states["dataloader"]
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"]
# 3. Check that the model parameters are loaded correctly
for batch in dataloader:
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
assert paddle.equal_all(res1["pred"], res2["pred"])
# 4. Check batch_idx
# TODO
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_batches = set()
for idx, batch in enumerate(replaced_loader):
left_batches.update(batch)
assert len(left_batches) + len(already_seen_set) == len(dataset)
assert len(left_batches | already_seen_set) == len(dataset)
finally:
synchronize_safe_rm(path)
@ -104,21 +128,36 @@ def test_save_and_load_with_randomsampler(only_state_dict):
dataset,
batch_sampler=batch_sampler
)
num_consumed_batches = 2
# TODO iterate a few more batches here once checkpoint-resume training is complete
already_seen_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_set.update(batch)
sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
# Load
# change the batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
# 1. Check the optimizer state
# TODO the optimizer state_dict is always empty
# 2. Check that the sampler is loaded and substituted correctly
replaced_loader = states["dataloader"]
replaced_loader = load_states["dataloader"]
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
@ -129,60 +168,51 @@ def test_save_and_load_with_randomsampler(only_state_dict):
# 3. Check that the model parameters are loaded correctly
for batch in dataloader:
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
assert paddle.equal_all(res1["pred"], res2["pred"])
# 4. Check batch_idx
# TODO
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_batches = set()
for idx, batch in enumerate(replaced_loader):
left_batches.update(batch)
assert len(left_batches) + len(already_seen_set) == len(dataset)
assert len(left_batches | already_seen_set) == len(dataset)
finally:
synchronize_safe_rm(path)
def test_save_and_load_state_dict(prepare_test_save_load):
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_model(prepare_test_save_load, only_state_dict):
"""
Tests the save and load functions.
TODO the optimizer state_dict is empty, so it is not tested for now
"""
try:
path = "dict"
driver1, driver2, dataloader = prepare_test_save_load
driver1.save_model(path)
driver2.load_model(path)
for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
assert paddle.equal_all(res1["pred"], res2["pred"])
finally:
synchronize_safe_rm(path)
def test_save_and_load_whole_model(prepare_test_save_load):
"""
Tests the save and load functions.
TODO the optimizer state_dict is empty, so it is not tested for now
Tests the save_model and load_model functions.
"""
try:
path = "model"
driver1, driver2, dataloader = prepare_test_save_load
driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))])
driver2.load_model(path, only_state_dict=False)
if only_state_dict:
driver1.save_model(path, only_state_dict)
else:
driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((32, 10))])
driver2.load_model(path, only_state_dict)
for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
assert paddle.equal_all(res1["pred"], res2["pred"])
finally:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")
if only_state_dict:
synchronize_safe_rm(path)
else:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")
class TestSingleDeviceFunction:
"""
@ -199,13 +229,7 @@ class TestSingleDeviceFunction:
Tests that it runs at all.
"""
res = self.driver.unwrap_model()
def test_check_evaluator_mode(self):
"""
These two functions return nothing and raise nothing; this only checks for import errors and similar issues that would prevent them from running.
"""
self.driver.check_evaluator_mode("validate")
self.driver.check_evaluator_mode("test")
assert res is self.driver.model
def test_is_distributed(self):
assert self.driver.is_distributed() == False
@ -237,44 +261,55 @@ class TestSetDistReproDataloder:
assert replaced_loader is dataloader
def test_set_dist_repro_dataloader_with_reproducible_true(self):
@pytest.mark.parametrize("shuffle", [True, False])
def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle):
"""
Tests the behaviour of set_dist_repro_dataloader when the `reproducible` argument is True.
When dist is a string, a new dataloader should be returned with a RandomBatchSampler as its batch_sampler.
When dist is a string, a new dataloader should be returned; if the original sampler is a paddle.io.RandomSampler (shuffle=True),
only the sampler is replaced with a RandomSampler; otherwise the batch_sampler is replaced with a RandomBatchSampler.
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
if shuffle:
# in this case the sampler is replaced
assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# in this case the batch_sampler is replaced
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
# self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
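A compact restatement of the branch these assertions encode, using the samplers imported by this test file; this is the expected outcome, not the driver's actual implementation:
def expected_replacement(shuffle):
    # shuffle=True : keep a paddle.io.BatchSampler, swap its sampler for fastNLP's RandomSampler
    # shuffle=False: wrap the original batch_sampler in a RandomBatchSampler
    return ("sampler", RandomSampler) if shuffle else ("batch_sampler", RandomBatchSampler)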
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is not a string but a ReproducibleBatchSampler.
A new dataloader should be returned, with its batch_sampler replaced by the sampler passed as dist.
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler is dist
self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
def test_set_dist_repro_dataloader_with_dist_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when dist is not a string.
A new dataloader should be returned, with batch_sampler.sampler replaced by the sampler passed as dist.
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
dist = RandomSampler(self.dataset, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
assert not (replaced_loader is dataloader)
@ -284,16 +319,21 @@ class TestSetDistReproDataloder:
assert replaced_loader.batch_sampler.sampler is dist
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when the dataloader already supports checkpoint resumption.
A new dataloader should be returned, with all other settings identical to the original.
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
batch_sampler=RandomBatchSampler(
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
batch_size=4,
drop_last=False,
)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
@ -303,15 +343,16 @@ class TestSetDistReproDataloder:
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle):
"""
Test the behavior of set_dist_repro_dataloader when the dataloader already supports checkpoint resumption.
A new dataloader should be returned, with all other settings identical to the original.
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@ -323,11 +364,11 @@ class TestSetDistReproDataloder:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
"""
Check that set_dist_repro_dataloader produces correct results on a single device.
"""
@ -346,9 +387,6 @@ class TestSetDistReproDataloder:
# load num_consumed_samples_array and set the number of batches already consumed
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
import time
time.sleep(5)
# after reloading, the remaining samples should be produced; for PaddleNormalDataset the sorted output should be a contiguous range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
@ -357,16 +395,29 @@ class TestSetDistReproDataloder:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
replaced_loader.batch_sampler.load_state_dict(sampler_states)
# rebuild the dataloader
new_loader = DataLoader(
dataset=replaced_loader.dataset,
batch_sampler=RandomBatchSampler(
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
batch_size=batch_size,
drop_last=False,
)
)
new_loader.batch_sampler.load_state_dict(sampler_states)
else:
batch_size = replaced_loader.batch_sampler.batch_size
num_consumed_batches = num_consumed_batches * batch_size
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
replaced_loader.batch_sampler.sampler.set_epoch(0)
for idx, batch in enumerate(replaced_loader):
# rebuild the dataloader
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(new_loader):
left_idxes.update(batch)
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
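The resume arithmetic exercised above, in isolation: after consuming k batches of size b, a reproducible sampler restarts at sample k * b (numbers below are illustrative):
num_consumed_batches, batch_size, dataset_len = 4, 4, 40
resume_at = num_consumed_batches * batch_size    # samples already seen before the restart
remaining = dataset_len - resume_at              # what the rebuilt loader must still yield
assert resume_at + remaining == dataset_len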

View File

@ -1,31 +0,0 @@
import unittest
import random
from fastNLP.core.samplers import SequentialSampler, RandomSampler, BucketSampler
from fastNLP.core.dataset import DataSet
from array import array
import torch
from fastNLP.core.samplers.sampler import ReproduceBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
class SamplerTest(unittest.TestCase):
def test_sequentialsampler(self):
ds = DataSet({'x': [1, 2, 3, 4] * 10})
sqspl = SequentialSampler(ds)
for idx, inst in enumerate(sqspl):
self.assertEqual(idx, inst)
def test_randomsampler(self):
ds = DataSet({'x': [1, 2, 3, 4] * 10})
rdspl = RandomSampler(ds)
ans = [ds[i] for i in rdspl]
self.assertEqual(len(ans), len(ds))
def test_bucketsampler(self):
data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10})
sampler = BucketSampler(data_set, num_buckets=3, batch_size=16, seq_len_field_name="seq_len")

View File

@ -1,6 +1,6 @@
import os
from fastNLP.envs.set_env import dump_fastnlp_backend
from fastNLP.envs.set_backend import dump_fastnlp_backend
from tests.helpers.utils import Capturing
from fastNLP.core import synchronize_safe_rm

View File

@ -72,7 +72,7 @@ class RecordTrainerEventTriggerCallback(Callback):
print("on_train_end")
def on_train_epoch_begin(self, trainer):
if trainer.current_epoch_idx >= 1:
if trainer.cur_epoch_idx >= 1:
# trigger on_exception
raise Exception
print("on_train_epoch_begin")

View File

@ -0,0 +1,68 @@
"""
This file provides Trainer arguments for testing callbacks. They can be used directly to initialize a Trainer; a test only needs to additionally pass in the callback under test to run.
"""
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.metrics import Accuracy
if _NEED_IMPORT_TORCH:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
class DataSet:
def __init__(self, num_samples=1000, num_features=10):
g = torch.Generator()
g.manual_seed(1000)
self.data = torch.randn(num_samples, num_features, generator=g)
self.y = self.data.argmax(dim=-1)
def __getitem__(self, item):
return {'x': self.data[item], 'target': self.y[item]}
def __len__(self):
return len(self.data)
class Model(nn.Module):
def __init__(self, num_features=5):
super().__init__()
self.mlps = nn.Sequential(
nn.Linear(num_features, 20),
nn.ReLU(),
nn.Linear(20, 20),
nn.Dropout(p=0.3),
nn.ReLU(),
nn.Linear(20, num_features)
)
def forward(self, x, target):
y = self.mlps(x)
if self.training:
return {'loss': F.cross_entropy(y, target)}
return {'pred': y}
def get_trainer_args(num_features=5, num_samples=20, bsz=4, lr=0.1, n_epochs=5, device=None):
ds = DataSet(num_samples=num_samples, num_features=num_features)
dl = DataLoader(ds, batch_size=bsz)
model = Model(num_features=num_features)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
kwargs = {
'model': model,
'driver': 'torch',
'device': device,
'optimizers': optimizer,
'train_dataloader': dl,
'evaluate_dataloaders': dl,
'metrics': {'acc': Accuracy()},
'n_epochs': n_epochs
}
return kwargs
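A hypothetical usage of this helper; the Trainer import path and run() call are assumptions based on this repo's conventions, and MyCallback stands in for the callback under test:
# from fastNLP.core.controllers import Trainer          # assumed import path
# trainer = Trainer(**get_trainer_args(device="cpu"), callbacks=[MyCallback()])
# trainer.run()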

View File

@ -26,7 +26,7 @@ class PaddleNormalModel_Classification_1(paddle.nn.Layer):
x = self(x)
return {"loss": self.loss_fn(x, y)}
def validate_step(self, x, y):
def evaluate_step(self, x, y):
x = self(x)
return {"pred": x, "target": y.reshape((-1,))}