From 9d50e99bfb43581871263133184ac7f28f9e4f17 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Fri, 15 Apr 2022 06:57:31 +0000
Subject: [PATCH 1/8] Fix the error message for evaluate_dataloader
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/controllers/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index 4ff8ba80..f26d841c 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -219,10 +219,10 @@ class Trainer(TrainerEventTrigger):
         """ 设置内部的 Evaluator """
         if metrics is None and evaluate_dataloaders is not None:
-            raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.")
+            raise ValueError("You have set 'evaluate_dataloader' but forget to set 'metrics'.")
 
         if metrics is not None and evaluate_dataloaders is None:
-            raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.")
+            raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloader'.")
 
         self.evaluator = None
         self.monitor = monitor

From 7c70874b4a0424af8d157ad4ed43ededee464b61 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Fri, 15 Apr 2022 16:04:43 +0800
Subject: [PATCH 2/8] Remove core.samplers.sampler.py; add torch gradient
 clipping and warmup callbacks
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callbacks/__init__.py            |   6 +-
 .../callbacks/torch_callbacks/__init__.py     |   8 +
 .../torch_grad_clip_callback.py               |  52 ++
 .../torch_lr_sched_callback.py                |  58 ++
 fastNLP/core/drivers/driver.py                |   2 +-
 .../drivers/torch_driver/single_device.py     |   8 +-
 fastNLP/core/samplers/__init__.py             |   6 -
 fastNLP/core/samplers/sampler.py              | 728 ------------------
 fastNLP/io/loader/conll.py                    |   3 +-
 .../callbacks/torch_callbacks/__init__.py     |   0
 .../test_torch_grad_clip_callback.py          |  41 +
 .../test_torch_warmup_callback.py             |  34 +
 tests/core/samplers/test_sampler.py           |  31 -
 .../prepare_trainer_args_for_torch_test.py    |  68 ++
 14 files changed, 275 insertions(+), 770 deletions(-)
 create mode 100644 fastNLP/core/callbacks/torch_callbacks/__init__.py
 create mode 100644 fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py
 create mode 100644 fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py
 delete mode 100644 fastNLP/core/samplers/sampler.py
 create mode 100644 tests/core/callbacks/torch_callbacks/__init__.py
 create mode 100644 tests/core/callbacks/torch_callbacks/test_torch_grad_clip_callback.py
 create mode 100644 tests/core/callbacks/torch_callbacks/test_torch_warmup_callback.py
 delete mode 100644 tests/core/samplers/test_sampler.py
 create mode 100644 tests/helpers/callbacks/prepare_trainer_args_for_torch_test.py

diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py
index fc5d9d5b..58de0319 100644
--- a/fastNLP/core/callbacks/__init__.py
+++ b/fastNLP/core/callbacks/__init__.py
@@ -11,7 +11,10 @@ __all__ = [
     'RichCallback',
     "LRSchedCallback",
     'LoadBestModelCallback',
-    "EarlyStopCallback"
+    "EarlyStopCallback",
+
+    "TorchWarmupCallback",
+    "TorchGradClipCallback"
 ]
 
 
@@ -23,4 +26,5 @@ from .progress_callback import choose_progress_callback, ProgressCallback, RichC
 from .lr_scheduler_callback import LRSchedCallback
 from
.load_best_model_callback import LoadBestModelCallback from .early_stop_callback import EarlyStopCallback +from .torch_callbacks import * diff --git a/fastNLP/core/callbacks/torch_callbacks/__init__.py b/fastNLP/core/callbacks/torch_callbacks/__init__.py new file mode 100644 index 00000000..1cadd7f6 --- /dev/null +++ b/fastNLP/core/callbacks/torch_callbacks/__init__.py @@ -0,0 +1,8 @@ +__all__ = [ + 'TorchWarmupCallback', + 'TorchGradClipCallback' +] + + +from .torch_lr_sched_callback import TorchWarmupCallback +from .torch_grad_clip_callback import TorchGradClipCallback \ No newline at end of file diff --git a/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py b/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py new file mode 100644 index 00000000..d5104a26 --- /dev/null +++ b/fastNLP/core/callbacks/torch_callbacks/torch_grad_clip_callback.py @@ -0,0 +1,52 @@ +__all__ = [ + 'TorchGradClipCallback' +] +from ..callback import Callback + + +class TorchGradClipCallback(Callback): + def __init__(self, clip_value=1, clip_type='norm', parameters=None): + r""" + 在每次 optimizer update 之前将 parameter 进行 clip + + :param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 + :param str clip_type: 支持'norm', 'value' + 两种:: + + 1 'norm', 将gradient的norm rescale到[-clip_value, clip_value] + + 2 'value', 将gradient限制在[-clip_value, clip_value], + 小于-clip_value的gradient被赋值为-clip_value; + 大于clip_value的gradient被赋值为clip_value. + :param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 + 如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。 + """ + super().__init__() + + from torch import nn + if clip_type == 'norm': + self.clip_fun = nn.utils.clip_grad_norm_ + elif clip_type == 'value': + self.clip_fun = nn.utils.clip_grad_value_ + else: + raise ValueError("Only supports `norm` or `value` right now.") + if parameters is not None: + self.parameters = list(parameters) + else: + self.parameters = None + self.clip_value = clip_value + + def on_after_trainer_initialized(self, trainer, driver): + assert 'torch' in driver.__class__.__name__.lower(), f"Callback:{self.__class__.__name__} only supports torch " \ + f"related drivers for now." + parameters = [] + for optimizer in trainer.driver.optimizers: + for param_group in optimizer.param_groups: + parameters.extend(param_group['params']) + self.parameters = parameters + assert len(self.parameters), "There is no parameters need to be clipped." + + def on_before_optimizers_step(self, trainer, optimizers): + for optimizer in trainer.driver.optimizers: + trainer.driver.grad_scaler.unscale_(optimizer) + self.clip_fun(self.parameters, self.clip_value) diff --git a/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py b/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py new file mode 100644 index 00000000..3d428d47 --- /dev/null +++ b/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py @@ -0,0 +1,58 @@ +__all__ = [ + 'TorchWarmupCallback' +] +import math + +from ..callback import Callback + + +class TorchWarmupCallback(Callback): + def __init__(self, warmup=0.1, schedule='constant'): + r""" + 调整 learning rate 的 callback 。仅在实际发生参数更新的情况下 + + :param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, + 如0.1, 则前10%的step是按照schedule策略调整learning rate。 + :param str schedule: 以哪种方式调整。 + linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后warmup的step下降到0; + constant前warmup的step上升到指定learning rate,后面的step保持learning rate. 
+ """ + super().__init__() + self.warmup = max(warmup, 0.) + + self.initial_lrs = [] # 存放param_group的learning rate + if schedule == 'constant': + self.get_lr = self._get_constant_lr + elif schedule == 'linear': + self.get_lr = self._get_linear_lr + else: + raise RuntimeError("Only support 'linear', 'constant'.") + + def _get_constant_lr(self, progress): + if progress 1: + self.warmup = self.warmup / self.t_steps + self.t_steps = max(2, self.t_steps) # 不能小于2 + # 防止 t_steps 不能整除 accumulation_steps + self.t_steps = math.ceil(self.t_steps/trainer.accumulation_steps) * trainer.accumulation_steps + # 获取param_group的初始learning rate + for optimizer in trainer.driver.optimizers: + for group in optimizer.param_groups: + self.initial_lrs.append(group['lr']) + + def on_before_optimizers_step(self, trainer, optimizers): + # 这里需要加 accumulation_steps 是防止 lr 从 0 开始 + progress = (trainer.global_forward_batches + trainer.accumulation_steps) / self.t_steps + for optimizer in trainer.driver.optimizers: + for lr, group in zip(self.initial_lrs, optimizer.param_groups): + group['lr'] = lr * self.get_lr(progress) diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index 0ef7f053..06547516 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -129,7 +129,7 @@ class Driver(ABC): @property def optimizers(self) -> List: r""" - 如下所示,driver 返回的 optimizers 一定是一个 List,如果用户直接向 Trainer 传入一个单独的 optimzer,我们会使用一个 List 将其 + 如下所示,driver 返回的 optimizers 一定是一个 List,如果用户直接向 Trainer 传入一个单独的 optimizer,我们会使用一个 List 将其 包裹; :return: List[optimizer0, optimizer1, optimizer2, ...] diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index adc61bd1..99ba754e 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -37,7 +37,12 @@ class TorchSingleDriver(TorchDriver): super(TorchSingleDriver, self).__init__(model, fp16=fp16, **kwargs) if device is None: - raise ValueError("Parameter `device` can not be None in `TorchSingleDriver`.") + logger.debug("device is not set, fastNLP will try to automatically get it.") + try: + device = next(model.parameters()).device + assert isinstance(device, torch.device) + except: + raise ValueError("fastNLP cannot get device automatically, please set device explicitly.") self.model_device = device @@ -70,6 +75,7 @@ class TorchSingleDriver(TorchDriver): return self.model, model.forward else: + # TODO 这种直接调用模型某个接口的方法无法触发hook,也许需要做一个warning,如果用户有钩子,提醒他train_step无法触发。 if hasattr(self.model, fn): fn = getattr(self.model, fn) if not callable(fn): diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py index 61433e8e..edc1f891 100644 --- a/fastNLP/core/samplers/__init__.py +++ b/fastNLP/core/samplers/__init__.py @@ -1,9 +1,4 @@ __all__ = [ - 'BucketSampler', - 'SortedSampler', - 'ConstTokenNumSampler', - 'ConstantTokenNumSampler', - 'MixSampler', 'DopedSampler', 'MixSequentialSampler', @@ -26,7 +21,6 @@ __all__ = [ "re_instantiate_sampler" ] -from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler diff --git a/fastNLP/core/samplers/sampler.py 
b/fastNLP/core/samplers/sampler.py deleted file mode 100644 index 89751884..00000000 --- a/fastNLP/core/samplers/sampler.py +++ /dev/null @@ -1,728 +0,0 @@ -r""" -sampler 子类实现了 fastNLP 所需的各种采样器。 -""" - -__all__ = [ - "BucketSampler", - "SortedSampler", - 'ConstTokenNumSampler', - "ConstantTokenNumSampler", -] - -from itertools import chain -from typing import List, Iterable - -import numpy as np - -from fastNLP.envs.imports import _NEED_IMPORT_TORCH - -if _NEED_IMPORT_TORCH: - from torch.utils.data import Sampler -else: - from fastNLP.core.utils.dummy_class import DummyClass as Sampler - -# class DopedSampler(Sampler): -# """ -# 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表混合采样组成一个个batch返回。 -# """ -# -# def __init__(self, dataset: Union[List, Dict], batch_size: int = None, -# sampler: Union[List[Sampler], Dict[str, Sampler]] = None, -# ds_ratio: Union[str, None, List[float], Dict[str, float]] = None, drop_last: bool = False) -> None: -# if batch_size <= 0: -# raise ValueError("batch_size should be a positive integer value, " -# "but got batch_size={}".format(batch_size)) -# if not isinstance(drop_last, bool): -# raise ValueError("drop_last should be a boolean value, but got " -# "drop_last={}".format(drop_last)) -# self.batch_size = batch_size -# self.drop_last = drop_last -# self.ds_ratio = ds_ratio -# if sampler is None: -# if isinstance(dataset, List): -# self.sampler = [SequentialSampler(ds) for ds in dataset] -# elif isinstance(dataset, Dict): -# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} -# -# elif isinstance(sampler, List): -# if len(sampler) != len(dataset): -# raise ValueError("the length of sampler != the length of sampler") -# self.sampler = sampler -# else: -# self.sampler = sampler -# if ds_ratio == 'pad_to_most' or ds_ratio == 'truncate_to_least' or ds_ratio is None: -# self.ds_ratio = ds_ratio -# elif isinstance(ds_ratio, List): -# if not all(item >= 0 for item in ds_ratio): -# raise ValueError("batch_size should be a positive integer value, " -# "but got batch_size={}".format(ds_ratio)) -# self.ds_ratio = ds_ratio -# else: -# raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None") -# -# def __iter__(self): -# samplers, index = [], 0 -# if isinstance(self.sampler, List): -# for idx, sampler in enumerate(self.sampler): -# samplers.append((iter(sampler), self.batch_size, index, 0, idx)) -# index += len(sampler) -# elif isinstance(self.sampler, Dict): -# for name, sampler in self.sampler.items(): -# samplers.append((iter(sampler), self.batch_size, index, 0, name)) -# index += len(sampler) -# -# def __len__(self): -# lens = 0 -# max_len, ds_len = 0, 0 -# if self.ds_ratio == 'truncate_to_least': -# if isinstance(self.sampler, List): -# max_len = min(len(sampler) for sampler in self.sampler) -# ds_len = len(self.sampler) -# elif isinstance(self.sampler, Dict): -# max_len = min(len(sampler) for _, sampler in self.sampler.items()) -# for _, _ in self.sampler.items(): -# ds_len += 1 -# -# elif self.ds_ratio == 'pad_to_most': -# if isinstance(self.sampler, List): -# max_len = max(len(sampler) for sampler in self.sampler) -# ds_len = len(self.sampler) -# elif isinstance(self.sampler, Dict): -# max_len = max(len(sampler) for _, sampler in self.sampler.items()) -# for _, _ in self.sampler.items(): -# ds_len += 1 -# -# if self.ds_ratio is None: -# if isinstance(self.sampler, List): -# for i in range(len(self.sampler)): -# sampler = self.sampler[i] -# if self.drop_last: -# lens += len(sampler) // self.batch_size -# else: 
-# lens += (len(sampler) + self.batch_size - 1) // self.batch_size -# elif isinstance(self.sampler, Dict): -# for name, sampler in self.sampler.items(): -# if self.drop_last: -# lens += len(sampler) // self.batch_size -# else: -# lens += (len(sampler) + self.batch_size - 1) // self.batch_size -# elif self.ds_ratio == 'truncate_to_least' or self.ds_ratio == 'pad_to_most': -# for i in range(ds_len): -# if self.drop_last: -# lens += max_len // self.batch_size -# else: -# lens += (max_len + self.batch_size - 1) // self.batch_size -# return lens -# -# def demo(self): -# indexes = np.array([0]*self.batch_size + [1]*self.batch_size + [2]*self.batch_size) -# shift = np.array([0]*self.batch_size + [len(ds1)]*self.batch_size + [len(ds1)+len(ds2)]*self.batch_size) -# buffer = np.zeros(self.batch_size*self.num_ds, dtype=int) -# select_sampler = np.random.randint(0, self.batch_size*self.num_ds, num_sample=self.batch_size) -# select_indices = buffer[select_sampler] + shift[select_sampler] -# num_1 = (indexes[select_sampler]==0).sum() -# - - -# class MixSequentialSampler(Sampler): -# """ -# 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表顺序采样并返回index,只有处理了上一个dataset才会处理下一个。 -# """ -# -# def __init__(self, dataset: Union[List, Dict], batch_size: int = None, -# sampler: Union[List[Sampler], Dict[str, Sampler], None] = None, -# drop_last: bool = False) -> None: -# """ -# -# :param dataset: 实现了__getitem__和__len__的数据容器列表 -# :param batch_size: 对应dataset的批次大小,可以为list或者为int,当为int时默认所有dataset -# :param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 -# :param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size -# """ -# # 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, -# if isinstance(dataset, Dict) and isinstance(sampler, List): -# raise ValueError(f"{sampler} must be dict") -# -# # 判断batch_size是否大于等于0 -# if batch_size <= 0: -# raise ValueError("batch_size should be a positive integer value, " -# "but got batch_size={}".format(batch_size)) -# -# if not isinstance(drop_last, bool): -# raise ValueError("drop_last should be a boolean value, but got " -# "drop_last={}".format(drop_last)) -# self.batch_size = batch_size -# self.drop_last = drop_last -# if sampler is None: -# if isinstance(dataset, List): -# self.sampler = [SequentialSampler(ds) for ds in dataset] -# elif isinstance(dataset, Dict): -# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} -# elif isinstance(sampler, List): -# if len(sampler) != len(dataset): -# raise ValueError("the length of sampler != the length of sampler") -# self.sampler = sampler -# -# def __iter__(self) -> Iterable[List[int]]: -# """ -# 按照dataset的顺序采样,打包成一个batch后返回 -# :return: -# """ -# index = 0 -# batch = [] -# if isinstance(self. 
sampler, List): -# for i in range(len(self.sampler)): -# sampler = self.sampler[i] -# for idx in sampler: -# batch.append(idx + index) -# if len(batch) == self.batch_size: -# yield batch -# batch = [] -# if len(batch) > 0 and not self.drop_last: -# yield batch -# batch = [] -# index += len(sampler) -# elif isinstance(self.sampler, Dict): -# for name, sampler in self.sampler.items(): -# for idx in sampler: -# batch.append(idx + index) -# if len(batch) == self.batch_size: -# yield batch -# batch = [] -# if len(batch) > 0 and not self.drop_last: -# yield batch -# batch = [] -# index += len(sampler) -# -# def __len__(self) -> int: -# lens = 0 -# if isinstance(self.sampler, List): -# for i in range(len(self.sampler)): -# sampler = self.sampler[i] -# if self.drop_last: -# lens += len(sampler) // self.batch_size -# else: -# lens += (len(sampler) + self.batch_size - 1) // self.batch_size -# elif isinstance(self.sampler, Dict): -# for _, sampler in self.sampler.items(): -# if self.drop_last: -# lens += len(sampler) // self.batch_size -# else: -# lens += (len(sampler) + self.batch_size - 1) // self.batch_size -# return lens - - -# class PollingSampler(Sampler): -# """ -# 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表轮流采样并返回index,处理了上个dataset的一个batch后会处理下一个。 -# """ -# -# def __init__(self, dataset: Union[List, Dict], batch_size: int = 16, -# sampler: Union[List[Sampler], Dict[str, Sampler]] = None, -# drop_last: bool = False, ds_ratio="pad_to_most") -> None: -# """ -# -# :param dataset: 实现了__getitem__和__len__的数据容器列表 -# :param batch_size: 对应dataset的批次大小,可以为list或者为int,当为int时默认所有dataset -# :param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 -# :param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size -# :param ds_ratio: 当ds_ratio=None时候, 轮流采样dataset列表直至所有的数据集采样完;当ds_ratio='truncate_to_least'时, -# 以dataset列表最短的ds为基准,长的数据集会被截断;当ds_ratio='pad_to_most'时,以dataset列表最长ds为基准,短的数据集会被重采样 -# """ -# # 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, -# if isinstance(dataset, Dict) and isinstance(sampler, List): -# raise ValueError(f"{sampler} must be dict") -# if isinstance(dataset, List) and isinstance(sampler, Dict): -# raise ValueError(f"{sampler} must be list") -# # 判断batch_size是否大于等于0 -# if batch_size <= 0: -# raise ValueError("batch_size should be a positive integer value, " -# "but got batch_size={}".format(batch_size)) -# -# if not isinstance(drop_last, bool): -# raise ValueError("drop_last should be a boolean value, but got " -# "drop_last={}".format(drop_last)) -# -# self.batch_size = batch_size -# self.drop_last = drop_last -# if sampler is None: -# if isinstance(dataset, List): -# self.sampler = [SequentialSampler(ds) for ds in dataset] -# elif isinstance(dataset, Dict): -# self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} -# -# elif isinstance(sampler, List): -# if len(sampler) != len(dataset): -# raise ValueError("the length of sampler != the length of sampler") -# self.sampler = sampler -# else: -# self.sampler = sampler -# if ds_ratio == 'pad_to_most' or ds_ratio == 'truncate_to_least' or ds_ratio is None: -# self.ds_ratio = ds_ratio -# else: -# raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None") -# -# def __iter__(self) -> Iterable[List[int]]: -# # index是数据集下标基址, pointer指向数据集列表的某个数据集 -# index, pointer, samplers, flag = 0, 0, [], False -# -# if isinstance(self.sampler, List): -# for idx, sampler in enumerate(self.sampler): -# samplers.append((iter(sampler), self.batch_size, index, 0, idx)) -# index += len(sampler) -# elif 
isinstance(self.sampler, Dict): -# for name, sampler in self.sampler.items(): -# samplers.append((iter(sampler), self.batch_size, index, 0, name)) -# index += len(sampler) -# if self.ds_ratio == 'pad_to_most': -# if isinstance(self.sampler, List): -# limit_len = max(len(ds) for ds in self.sampler) -# else: -# limit_len = max(len(ds) for _, ds in self.sampler.items()) -# elif self.ds_ratio == 'truncate_to_least': -# if isinstance(self.sampler, List): -# limit_len = min(len(ds) for ds in self.sampler) -# else: -# limit_len = min(len(ds) for _, ds in self.sampler.items()) -# else: -# limit_len = 0 -# # 最后一个批次的大小 -# last_batch_size = limit_len % self.batch_size -# -# while True: -# # 全部采样完,退出 -# if len(samplers) == 0: -# break -# batch, flag = [], False -# # sampler_len代表已经取出来的数据个数 -# sampler, batch_size, index, sampler_len, name = samplers.pop(0) -# for _ in range(batch_size): -# try: -# batch.append(index + next(sampler)) -# sampler_len += 1 -# except StopIteration: -# flag = True -# # ds_ratio为None,第一种情况,删除掉采样完的数据即可。 -# if self.ds_ratio == 'pad_to_most' and sampler_len < limit_len: -# # 重置sampler,并取足一个batch数据 -# sampler = iter(self.sampler[name]) -# # 由于batch_size一定小于等于ds的长度,故能够取足一个batch_size的数据 -# for _ in range(batch_size-len(batch)): -# batch.append(next(sampler) + index) -# sampler_len += 1 -# break -# -# # ds_ratio不为None情况 -# # 两种情况会触发一下逻辑:1.truncate_to_least时,最短的数据集最后一个batch大小不等于batch_size时, -# # 其他较长的数据集的最后一个batch长度会较长;2. pad_to_most,最长的数据集最后一个batch不等于batch_size时,较短数据集最后一个 -# # batch长度会较长 -# if limit_len != 0 and limit_len < sampler_len: -# batch = batch[:last_batch_size] -# # ds_ratio为任意情况下, 没有取完所有数据,则添加到队列尾部 -# elif (limit_len == 0 and flag == False) or limit_len > sampler_len: -# samplers.append((sampler, batch_size, index, sampler_len, name)) -# if len(batch) == batch_size: -# yield batch -# elif len(batch) > 0 and not self.drop_last: -# yield batch -# -# def __len__(self) -> int: -# lens = 0 -# max_len, ds_len = 0, 0 -# if self.ds_ratio == 'truncate_to_least': -# if isinstance(self.sampler, List): -# max_len = min(len(sampler) for sampler in self.sampler) -# ds_len = len(self.sampler) -# elif isinstance(self.sampler, Dict): -# max_len = min(len(sampler) for _, sampler in self.sampler.items()) -# for _, _ in self.sampler.items(): -# ds_len += 1 -# -# elif self.ds_ratio == 'pad_to_most': -# if isinstance(self.sampler, List): -# max_len = max(len(sampler) for sampler in self.sampler) -# ds_len = len(self.sampler) -# elif isinstance(self.sampler, Dict): -# max_len = max(len(sampler) for _, sampler in self.sampler.items()) -# for _, _ in self.sampler.items(): -# ds_len += 1 -# if self.ds_ratio is None: -# if isinstance(self.sampler, List): -# for i in range(len(self.sampler)): -# sampler = self.sampler[i] -# if self.drop_last: -# lens += len(sampler) // self.batch_size -# else: -# lens += (len(sampler) + self.batch_size - 1) // self.batch_size -# elif isinstance(self.sampler, Dict): -# for name, sampler in self.sampler.items(): -# if self.drop_last: -# lens += len(sampler) // self.batch_size -# else: -# lens += (len(sampler) + self.batch_size - 1) // self.batch_size -# else: -# for i in range(ds_len): -# if self.drop_last: -# lens += max_len // self.batch_size -# else: -# lens += (max_len + self.batch_size - 1) // self.batch_size -# return lens - - -class BucketSampler(Sampler): - r""" - 带Bucket的 `Random Sampler`. 
可以随机地取出长度相似的元素 - """ - - def __init__(self, dataset, num_buckets=10, batch_size=None, seq_len_field_name='seq_len', drop_last=False) -> None: - r""" - - :param int num_buckets: bucket的数量 - :param int batch_size: batch的大小. 默认为None,Trainer/Tester在调用BucketSampler时,会将该值正确设置,如果是非 - Trainer/Tester场景使用,需要显示传递该值 - :param str seq_len_field_name: 对应序列长度的 `field` 的名字 - """ - self.dataset = dataset - self.num_buckets = num_buckets - self.batch_size = batch_size - self.seq_len_field_name = seq_len_field_name - - def set_batch_size(self, batch_size) -> None: - r""" - - :param int batch_size: 每个batch的大小 - :return: - """ - self.batch_size = batch_size - - def __iter__(self): - if self.batch_size is None: - raise RuntimeError("batch_size is None.") - seq_lens = self.dataset.get_all_fields()[self.seq_len_field_name].content - total_sample_num = len(seq_lens) - - bucket_indexes = [] - assert total_sample_num >= self.num_buckets, "The number of samples is smaller than the number of buckets." - num_sample_per_bucket = total_sample_num // self.num_buckets - for i in range(self.num_buckets): - bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)]) - bucket_indexes[-1][1] = total_sample_num - - sorted_seq_lens = list(sorted([(idx, seq_len) for - idx, seq_len in zip(range(total_sample_num), seq_lens)], - key=lambda x: x[1])) - - batchs = [] - - left_init_indexes = [] - for b_idx in range(self.num_buckets): - start_idx = bucket_indexes[b_idx][0] - end_idx = bucket_indexes[b_idx][1] - sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx] - left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens]) - num_batch_per_bucket = len(left_init_indexes) // self.batch_size - np.random.shuffle(left_init_indexes) - for i in range(num_batch_per_bucket): - batchs.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size]) - left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:] - if (left_init_indexes) != 0: - batchs.append(left_init_indexes) - np.random.shuffle(batchs) - - return chain(*batchs) - - -class ConstTokenNumSampler(Sampler): - """ - 尽量保证每个batch的输入token数量是接近的。 - - """ - - def __init__(self, dataset, seq_len_field_name: List[int], max_token: int = 4096, max_sentence: int = -1, - need_be_multiple_of: int = 1, num_bucket: int = -1) -> None: - """ - - :param dataset: - :param List[int] seq_len_field_name: 哪个field指示的sample的长度 - :param int max_token: 每个batch的最大的token数量 - :param int max_sentence: 每个batch最多多少个instance, -1表示根据max_token决定 - :param int need_be_multiple_of: 生成的batch的instance的数量需要是几的倍数,在DataParallel场景下会用到 - :param int num_bucket: 将数据按长度拆分为num_bucket个bucket,batch中的sample尽量在bucket之中进行组合,这样可以减少padding。 - """ - assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1 - self.dataset = dataset - self.seq_len_field_name = seq_len_field_name - self.num_bucket = num_bucket - self.max_token = max_token - self._max_sentence = max_sentence - self.need_be_multiple_of = need_be_multiple_of - - assert len(self.dataset) > self.num_bucket, "The number of samples should be larger than buckets." 
- seq_len = self.dataset.get_field(self.seq_len_field_name) - self.seq_len = seq_len - seq_len_indice = [(length, i) for i, length in enumerate(seq_len)] - seq_len_indice.sort(key=lambda x: x[0]) - indice_in_buckets = [] - if self.num_bucket > 0: - sample_per_bucket = len(seq_len_indice) // self.num_bucket - i = 0 - while len(indice_in_buckets) < len(seq_len_indice): - indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket]) - i += 1 - else: - indice_in_buckets = [seq_len_indice] - self.indice_in_buckets = indice_in_buckets - self.get_new_order() - - @property - def max_sentence(self): - if self._max_sentence < 1: - return 100000000 - return self._max_sentence - - @max_sentence.setter - def max_sentence(self, max_sentence): - self._max_sentence = max_sentence - - def get_new_order(self) -> None: - np.random.shuffle(self.indice_in_buckets) - for bucket in self.indice_in_buckets: - np.random.shuffle(bucket) - indices = list(chain(*self.indice_in_buckets)) - batches = [] - cur_max_len = 0 - batch = [] - for length, i in indices: - max_len = max(length, cur_max_len) - if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence: - left_sample = len(batch) % self.need_be_multiple_of - add_samples = batch.copy() - cur_max_len = length - if left_sample != 0: - add_samples = add_samples[:-left_sample] - batch = batch[-left_sample:] - cur_max_len = max(cur_max_len, max(batch)) - else: - batch = [] - if len(add_samples) == 0: - raise RuntimeError( - f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.") - batches.append(add_samples) - else: - cur_max_len = max_len - batch.append(i) - if batch: - left_sample = len(batch) % self.need_be_multiple_of - add_samples = batch.copy() - if left_sample != 0: - add_samples = add_samples[:-left_sample].copy() - if add_samples: - batches.append(add_samples) - np.random.shuffle(batches) - self.batches = batches - - def __iter__(self) -> Iterable[int]: - for batch in self.batches: - yield batch - self.get_new_order() - - def __len__(self): - return len(self.batches) - - -class ConstantTokenNumSampler: - """ - 尽量保证每个batch的输入token数量是接近的。 - - """ - - def __init__(self, seq_len, max_token: List[int] = 4096, max_sentence: int = -1, - need_be_multiple_of: int = 1, num_bucket: int = -1) -> None: - """ - - :param List[int] seq_len: list[int], 是每个sample的长度。一般可以通过dataset.get_field('seq_len').content传入 - :param int max_token: 每个batch的最大的token数量 - :param int max_sentence: 每个batch最多多少个instance, -1表示根据max_token决定 - :param int need_be_multiple_of: 生成的batch的instance的数量需要是几的倍数,在DataParallel场景下会用到 - :param int num_bucket: 将数据按长度拆分为num_bucket个bucket,batch中的sample尽量在bucket之中进行组合,这样可以减少padding。 - """ - assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1 - assert len(seq_len) > num_bucket, "The number of samples should be larger than buckets." 
- self.seq_len = seq_len - self.max_token = max_token - self._max_sentence = max_sentence - self.need_be_multiple_of = need_be_multiple_of - seq_len_indice = [(length, i) for i, length in enumerate(seq_len)] - seq_len_indice.sort(key=lambda x: x[0]) - indice_in_buckets = [] - if num_bucket > 0: - sample_per_bucket = len(seq_len_indice) // num_bucket - i = 0 - while len(indice_in_buckets) < len(seq_len_indice): - indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket]) - i += 1 - else: - indice_in_buckets = [seq_len_indice] - self.indice_in_buckets = indice_in_buckets - self.get_new_order() - - @property - def max_sentence(self): - if self._max_sentence < 1: - return 100000000 - return self._max_sentence - - @max_sentence.setter - def max_sentence(self, max_sentence): - self._max_sentence = max_sentence - - def get_new_order(self) -> None: - np.random.shuffle(self.indice_in_buckets) - for bucket in self.indice_in_buckets: - np.random.shuffle(bucket) - indices = list(chain(*self.indice_in_buckets)) - batches = [] - cur_max_len = 0 - batch = [] - for length, i in indices: - max_len = max(length, cur_max_len) - if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence: - left_sample = len(batch) % self.need_be_multiple_of - add_samples = batch.copy() - cur_max_len = length - if left_sample != 0: - add_samples = add_samples[:-left_sample] - batch = batch[-left_sample:] - cur_max_len = max(cur_max_len, max(batch)) - else: - batch = [] - if len(add_samples) == 0: - raise RuntimeError( - f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.") - batches.append(add_samples) - else: - cur_max_len = max_len - batch.append(i) - if batch: - left_sample = len(batch) % self.need_be_multiple_of - add_samples = batch.copy() - if left_sample != 0: - add_samples = add_samples[:-left_sample].copy() - if add_samples: - batches.append(add_samples) - np.random.shuffle(batches) - self.batches = batches - - def __iter__(self) -> Iterable[int]: - for batch in self.batches: - yield batch - self.get_new_order() - - def __len__(self): - return len(self.batches) - - -class SortedSampler(Sampler): - r""" - 按照sample的长度进行排序,主要在测试的时候使用,可以加速测试(因为减少了padding) - """ - - def __init__(self, dataset, seq_len_field_name: str = 'seq_len', descending: bool = True) -> None: - """ - - :param str seq_len_field_name: 按哪个field进行排序。如果传入的field是数字,则直接按照该数字大小排序;如果传入的field不是 - 数字,则使用该field的长度进行排序 - :param bool descending: 是否降序排列 - """ - self.dataset = dataset - self.seq_len_field_name = seq_len_field_name - self.descending = descending - - def __iter__(self) -> Iterable[int]: - seq_lens = self.dataset.get_field(self.seq_len_field_name).content - try: - seq_lens = list(map(len, seq_lens)) - except: - pass - - orders = np.argsort(seq_lens).tolist() # 从小到大的顺序 - if self.descending: - orders = orders[::-1] - for order in orders: - yield order - - -def simple_sort_bucketing(lengths): - r""" - - :param lengths: list of int, the lengths of all examples. - :return data: 2-level list - :: - - [ - [index_11, index_12, ...], # bucket 1 - [index_21, index_22, ...], # bucket 2 - ... - ] - - """ - lengths_mapping = [(idx, length) for idx, length in enumerate(lengths)] - sorted_lengths = sorted(lengths_mapping, key=lambda x: x[1]) - # TODO: need to return buckets - return [idx for idx, _ in sorted_lengths] - - -def k_means_1d(x, k, max_iter=100): - r"""Perform k-means on 1-D data. - - :param x: list of int, representing points in 1-D. 
- :param k: the number of clusters required. - :param max_iter: maximum iteration - :return centroids: numpy array, centroids of the k clusters - assignment: numpy array, 1-D, the bucket id assigned to each example. - """ - sorted_x = sorted(list(set(x))) - x = np.array(x) - if len(sorted_x) < k: - raise ValueError("too few buckets") - gap = len(sorted_x) / k - - centroids = np.array([sorted_x[int(x * gap)] for x in range(k)]) - assign = None - - for i in range(max_iter): - # Cluster Assignment step - assign = np.array([np.argmin([np.absolute(x_i - x) for x in centroids]) for x_i in x]) - # Move centroids step - new_centroids = np.array([x[assign == k].mean() for k in range(k)]) - if (new_centroids == centroids).all(): - centroids = new_centroids - break - centroids = new_centroids - return np.array(centroids), assign - - -def k_means_bucketing(lengths, buckets): - r"""Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths. - - :param lengths: list of int, the length of all samples. - :param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length - threshold for each bucket (This is usually None.). - :return data: 2-level list - :: - - [ - [index_11, index_12, ...], # bucket 1 - [index_21, index_22, ...], # bucket 2 - ... - ] - - """ - bucket_data = [[] for _ in buckets] - num_buckets = len(buckets) - _, assignments = k_means_1d(lengths, num_buckets) - - for idx, bucket_id in enumerate(assignments): - if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]: - bucket_data[bucket_id].append(idx) - return bucket_data diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py index e099331f..90045e46 100644 --- a/fastNLP/io/loader/conll.py +++ b/fastNLP/io/loader/conll.py @@ -50,8 +50,6 @@ class ConllLoader(Loader): ConllLoader返回的DataSet的field由传入的headers确定。 - 数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 - """ def __init__(self, headers, sep=None, indexes=None, dropna=True): @@ -93,6 +91,7 @@ class ConllLoader(Loader): class Conll2003Loader(ConllLoader): r""" 用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。 + 数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 Example:: diff --git a/tests/core/callbacks/torch_callbacks/__init__.py b/tests/core/callbacks/torch_callbacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/callbacks/torch_callbacks/test_torch_grad_clip_callback.py b/tests/core/callbacks/torch_callbacks/test_torch_grad_clip_callback.py new file mode 100644 index 00000000..7f2016e2 --- /dev/null +++ b/tests/core/callbacks/torch_callbacks/test_torch_grad_clip_callback.py @@ -0,0 +1,41 @@ +import pytest +import numpy as np + +from fastNLP.core.callbacks import TorchGradClipCallback, Callback +from fastNLP import Trainer +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + +from tests.helpers.callbacks.prepare_trainer_args_for_torch_test import get_trainer_args + + +class CheckClipCallback(Callback): + def __init__(self, parameters, clip_type, clip_value): + self.parameters = parameters + self.clip_type = clip_type + self.clip_value = clip_value + + def on_after_optimizers_step(self, trainer, optimizers): + for param in self.parameters: + if self.clip_type == 'value': + assert param.grad.max().item()<=self.clip_value + else: + assert np.linalg.norm(param.grad.cpu().view(-1).numpy())<=self.clip_value + + 
+@pytest.mark.parametrize('accumulation_steps', [1, 3, 5]) +@pytest.mark.parametrize('fp16', [True, False]) +@pytest.mark.parametrize('clip_type', ['norm', 'value']) +@pytest.mark.parametrize('clip_value', [1, 2]) +def test_torch_grad_clip_callback(accumulation_steps, fp16, clip_type, clip_value): + if not torch.cuda.is_available() and fp16: + pytest.skip("No cuda, cannot test fp16.") + device = 'cuda' if fp16 else 'cpu' + kwargs = get_trainer_args(lr=1, device=device) + callbacks = [] + callbacks.append(TorchGradClipCallback(clip_value=clip_value, clip_type=clip_type)) + callbacks.append(CheckClipCallback(kwargs['model'].parameters(), clip_type, clip_value)) + trainer = Trainer(**kwargs, callbacks=callbacks, fp16=fp16) + trainer.run() diff --git a/tests/core/callbacks/torch_callbacks/test_torch_warmup_callback.py b/tests/core/callbacks/torch_callbacks/test_torch_warmup_callback.py new file mode 100644 index 00000000..6367c458 --- /dev/null +++ b/tests/core/callbacks/torch_callbacks/test_torch_warmup_callback.py @@ -0,0 +1,34 @@ +import pytest +import numpy as np + +from fastNLP.core.callbacks import TorchWarmupCallback, Callback +from fastNLP import Trainer + +from tests.helpers.callbacks.prepare_trainer_args_for_torch_test import get_trainer_args + + +class RecordLrCallback(Callback): + def __init__(self): + self.lrs = [] + + def on_after_optimizers_step(self, trainer, optimizers): + self.lrs.append(trainer.driver.optimizers[0].param_groups[0]['lr']) + + +@pytest.mark.parametrize('warmup', [5, 0.1]) +@pytest.mark.parametrize('schedule', ['constant', 'linear']) +@pytest.mark.parametrize('accumulation_steps', [1, 3, 4]) +def test_torch_warmup_callback(warmup, schedule, accumulation_steps): + kwargs = get_trainer_args(lr=0.1, bsz=4) + callback = TorchWarmupCallback(warmup, schedule) + r_callback = RecordLrCallback() + kwargs['callbacks'] = [callback, r_callback] + trainer = Trainer(**kwargs, accumulation_steps=accumulation_steps) + trainer.run() + + if schedule == 'linear': + assert kwargs['optimizers'].param_groups[0]['lr'] <= 0.01 + elif schedule == 'constant': + assert np.allclose(0.1, kwargs['optimizers'].param_groups[0]['lr']) + + assert len(r_callback.lrs)<=trainer.total_batches//accumulation_steps+1 \ No newline at end of file diff --git a/tests/core/samplers/test_sampler.py b/tests/core/samplers/test_sampler.py deleted file mode 100644 index 63d8e860..00000000 --- a/tests/core/samplers/test_sampler.py +++ /dev/null @@ -1,31 +0,0 @@ -import unittest -import random -from fastNLP.core.samplers import SequentialSampler, RandomSampler, BucketSampler -from fastNLP.core.dataset import DataSet -from array import array -import torch - -from fastNLP.core.samplers.sampler import ReproduceBatchSampler -from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler -from tests.helpers.datasets.torch_data import TorchNormalDataset - - -class SamplerTest(unittest.TestCase): - - def test_sequentialsampler(self): - ds = DataSet({'x': [1, 2, 3, 4] * 10}) - sqspl = SequentialSampler(ds) - for idx, inst in enumerate(sqspl): - self.assertEqual(idx, inst) - - def test_randomsampler(self): - ds = DataSet({'x': [1, 2, 3, 4] * 10}) - rdspl = RandomSampler(ds) - ans = [ds[i] for i in rdspl] - self.assertEqual(len(ans), len(ds)) - - def test_bucketsampler(self): - data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10}) - sampler = BucketSampler(data_set, num_buckets=3, batch_size=16, seq_len_field_name="seq_len") - - diff --git 
a/tests/helpers/callbacks/prepare_trainer_args_for_torch_test.py b/tests/helpers/callbacks/prepare_trainer_args_for_torch_test.py
new file mode 100644
index 00000000..01544e0a
--- /dev/null
+++ b/tests/helpers/callbacks/prepare_trainer_args_for_torch_test.py
@@ -0,0 +1,68 @@
+
+"""
+这个文件主要用于提供测试 callback 时的 Trainer 的参数,可以直接用于对 Trainer 进行初始化。只需要再额外传入相应的 callback 就可以运行
+
+"""
+
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.core.metrics import Accuracy
+
+
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch import nn
+    from torch.utils.data import DataLoader
+    import torch.nn.functional as F
+
+    class DataSet:
+        def __init__(self, num_samples=1000, num_features=10):
+            g = torch.Generator()
+            g.manual_seed(1000)
+            self.data = torch.randn(num_samples, num_features, generator=g)
+            self.y = self.data.argmax(dim=-1)
+
+        def __getitem__(self, item):
+            return {'x': self.data[item], 'target': self.y[item]}
+
+        def __len__(self):
+            return len(self.data)
+
+
+    class Model(nn.Module):
+        def __init__(self, num_features=5):
+            super().__init__()
+            self.mlps = nn.Sequential(
+                nn.Linear(num_features, 20),
+                nn.ReLU(),
+                nn.Linear(20, 20),
+                nn.Dropout(p=0.3),
+                nn.ReLU(),
+                nn.Linear(20, num_features)
+            )
+
+        def forward(self, x, target):
+            y = self.mlps(x)
+            if self.training:
+                return {'loss': F.cross_entropy(y, target)}
+            return {'pred': y}
+
+
+def get_trainer_args(num_features=5, num_samples=20, bsz=4, lr=0.1, n_epochs=5, device=None):
+    ds = DataSet(num_samples=num_samples, num_features=num_features)
+    dl = DataLoader(ds, batch_size=bsz)
+    model = Model(num_features=num_features)
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+
+    kwargs = {
+        'model': model,
+        'driver': 'torch',
+        'device': device,
+        'optimizers': optimizer,
+        'train_dataloader': dl,
+        'evaluate_dataloaders': dl,
+        'metrics': {'acc': Accuracy()},
+        'n_epochs': n_epochs
+    }
+
+    return kwargs
\ No newline at end of file

From 8fc4fd19ffaca481390ebca3ee2ad3f3e2d9fad5 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Fri, 15 Apr 2022 08:05:39 +0000
Subject: [PATCH 3/8] small

---
 fastNLP/envs/set_backend.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/fastNLP/envs/set_backend.py b/fastNLP/envs/set_backend.py
index 6da62334..4d6fb915 100644
--- a/fastNLP/envs/set_backend.py
+++ b/fastNLP/envs/set_backend.py
@@ -5,7 +5,6 @@
 import os
 import json
 import sys
-import subprocess
 from collections import defaultdict
 

From d0f26c7c3449c4b75dc5a3e980cc91c2218f8db4 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Fri, 15 Apr 2022 08:10:01 +0000
Subject: [PATCH 4/8] Replace validate_step with evaluate_step
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/drivers/jittor_driver/mpi.py     |  13 +-
 .../drivers/jittor_driver/single_device.py    |  64 +--
 .../core/drivers/paddle_driver/dist_utils.py  | 376 ++++++++++++++++++
 fastNLP/core/drivers/paddle_driver/fleet.py   | 140 ++++---
 .../drivers/paddle_driver/paddle_driver.py    |  32 +-
 .../drivers/paddle_driver/single_device.py    | 137 +++----
 fastNLP/core/drivers/paddle_driver/utils.py   |  80 +---
 .../torch_paddle_driver.py                    |  59 +--
 fastNLP/modules/mix_modules/mix_module.py     |   2 +-
 tests/core/controllers/test_trainer_paddle.py |  80 ++--
 .../paddle_driver/test_single_device.py       | 150 ++++---
 tests/helpers/callbacks/helper_callbacks.py   |   2 +-
 tests/helpers/models/paddle_model.py          |   2 +-
 13 files changed, 683 insertions(+), 454 deletions(-)
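[Editor's note] Before the diffs of patch 4/8, a hedged usage sketch, not part of any patch: it wires the test helper above together with the two callbacks added in patch 2/8. It assumes fastNLP and torch are importable and that the repository root is on sys.path, as in the test suite.

```python
# Illustrative only: drive a tiny CPU training run with the new callbacks.
from fastNLP import Trainer
from fastNLP.core.callbacks import TorchGradClipCallback, TorchWarmupCallback
from tests.helpers.callbacks.prepare_trainer_args_for_torch_test import get_trainer_args

kwargs = get_trainer_args(lr=0.1, bsz=4, device='cpu')  # model/optimizer/dataloaders/metrics
callbacks = [
    TorchGradClipCallback(clip_value=1, clip_type='norm'),  # rescale grad norm before each step
    TorchWarmupCallback(warmup=0.1, schedule='linear'),     # 10% warmup, then linear decay
]
trainer = Trainer(**kwargs, callbacks=callbacks)
trainer.run()
```

This mirrors what the new tests do, with get_trainer_args supplying the model, optimizer, dataloaders and metric so each test only has to add the callback under test.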
create mode 100644 fastNLP/core/drivers/paddle_driver/dist_utils.py diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py index c467b868..98ac44a0 100644 --- a/fastNLP/core/drivers/jittor_driver/mpi.py +++ b/fastNLP/core/drivers/jittor_driver/mpi.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Union +from typing import Optional, Union, Callable, Dict, Tuple from .jittor_driver import JittorDriver from fastNLP.envs.imports import _NEED_IMPORT_JITTOR @@ -61,14 +61,11 @@ class JittorMPIDriver(JittorDriver): return self._data_device return self.model_device - def train_step(self, batch): - return self._train_step(batch) + def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: + pass - def validate_step(self, batch): - return self._validate_step(batch) - - def test_step(self, batch): - return self._test_step(batch) + def get_model_call_fn(self, fn: str) -> Tuple: + pass def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], reproducible: bool = False, sampler_or_batch_sampler=None): diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py index 84bdb28b..695e6ec9 100644 --- a/fastNLP/core/drivers/jittor_driver/single_device.py +++ b/fastNLP/core/drivers/jittor_driver/single_device.py @@ -1,9 +1,11 @@ -from typing import Dict, Union +from typing import Dict, Union, Tuple, Callable, Optional from .jittor_driver import JittorDriver from fastNLP.core.utils import auto_param_call +from fastNLP.core.utils.utils import _get_fun_msg from fastNLP.envs.imports import _NEED_IMPORT_JITTOR from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler +from fastNLP.core.log import logger if _NEED_IMPORT_JITTOR: import jittor @@ -27,42 +29,6 @@ class JittorSingleDriver(JittorDriver): self.global_rank = 0 self.world_size = 1 - if hasattr(self.model, "train_step"): - self._train_step = self.model.train_step - self._train_signature_fn = None - else: - self._train_step = self.model - model = self.unwrap_model() - self._train_signature_fn = model.execute - - if hasattr(self.model, "evaluate_step"): - self._validate_step = self.model.evaluate_step - self._validate_signature_fn = None - elif hasattr(self.model, "test_step"): - self._validate_step = self.model.test_step - self._validate_signature_fn = self.model.test_step - else: - self._validate_step = self.model - model = self.unwrap_model() - self._validate_signature_fn = model.execute - - if hasattr(self.model, "test_step"): - self._test_step = self.model.test_step - self._test_signature_fn = None - elif hasattr(self.model, "evaluate_step"): - self._test_step = self.model.evaluate_step - self._test_signature_fn = self.model.evaluate_step - else: - self._test_step = self.model - model = self.unwrap_model() - self._test_signature_fn = model.execute - - def train_step(self, batch) -> Dict: - if isinstance(batch, Dict): - return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) - else: - return self._train_step(batch) - def step(self): """ jittor optimizers 的step函数可以传入参数loss @@ -80,18 +46,24 @@ class JittorSingleDriver(JittorDriver): for optimizer in self.optimizers: optimizer.zero_grad() - def validate_step(self, batch): - if isinstance(batch, Dict): - return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) + def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: + 
if isinstance(batch, Dict) and not self.wo_auto_param_call: + return auto_param_call(fn, batch, signature_fn=signature_fn) else: - return self._validate_step(batch) + return fn(batch) - def test_step(self, batch): - - if isinstance(batch, Dict): - return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) + def get_model_call_fn(self, fn: str) -> Tuple: + if hasattr(self.model, fn): + fn = getattr(self.model, fn) + if not callable(fn): + raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") + logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') + return fn, None + elif fn in {"train_step", "evaluate_step"}: + logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...') + return self.model, self.model.forward else: - return self._test_step(batch) + raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") def unwrap_model(self): return self.model diff --git a/fastNLP/core/drivers/paddle_driver/dist_utils.py b/fastNLP/core/drivers/paddle_driver/dist_utils.py new file mode 100644 index 00000000..3bfbbd4f --- /dev/null +++ b/fastNLP/core/drivers/paddle_driver/dist_utils.py @@ -0,0 +1,376 @@ +import io +import pickle +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler +from typing import Any, List + +from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 +from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch import distributed as dist + if _TORCH_GREATER_EQUAL_1_8: + try: + from torch._C._distributed_c10d import ProcessGroupGloo + from torch._C._distributed_c10d import _ProcessGroupWrapper + except ImportError: + pass + + +from fastNLP.core.utils import apply_to_collection + + +def _validate_output_list_for_rank(my_rank, dst, gather_list): + if dst == my_rank: + if not gather_list: + raise ValueError( + "Argument ``gather_list`` must be specified on destination rank." + ) + elif gather_list: + raise ValueError( + "Argument ``gather_list`` must NOT be specified " + "on non-destination ranks." + ) + + +def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP): + """ + 从其它 rank gather 东西到 dst rank 。 + + Gathers picklable objects from the whole group in a single process. + Similar to :func:`gather`, but Python objects can be passed in. Note that the + object must be picklable in order to be gathered. + + Args: + obj (Any): Input object. Must be picklable. + object_gather_list (list[Any]): Output list. On the ``dst`` rank, it + should be correctly sized as the size of the group for this + collective and will contain the output. Must be ``None`` on non-dst + ranks. (default is ``None``) + dst (int, optional): Destination rank. (default is 0) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + + Returns: + None. On the ``dst`` rank, ``object_gather_list`` will contain the + output of the collective. + + .. note:: Note that this API differs slightly from the gather collective + since it does not provide an async_op handle and thus will be a blocking + call. + + .. note:: Note that this API is not supported when using the NCCL backend. + + .. warning:: + :func:`gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. 
Only call this + function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.gather_object( + gather_objects[dist.get_rank()], + output if dist.get_rank() == 0 else None, + dst=0 + ) + >>> # On rank 0 + >>> output + ['foo', 12, {1: 2}] + """ + if group is None: + group = DEFAULT_TORCH_GROUP + + if dist.distributed_c10d._rank_not_in_group(group): + return + + # Ensure object_gather_list is specified appopriately. + my_rank = dist.get_rank() + _validate_output_list_for_rank(my_rank, dst, object_gather_list) + # 防止 unpickle 的时候出现在了发送的 gpu 上。 + obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) + input_tensor, local_size = _object_to_tensor(obj) + group_backend = dist.get_backend(group) + current_device = torch.device("cpu") + is_nccl_backend = group_backend == dist.Backend.NCCL + if is_nccl_backend: + current_device = torch.device('cuda', torch.cuda.current_device()) + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = dist.get_world_size(group=group) + object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. + dist.all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + # Avoid populating output tensors if the result won't be gathered on this rank. + if my_rank == dst: + coalesced_output_tensor = torch.empty( + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + # All ranks call gather with equal-sized tensors. + dist.gather( + input_tensor, + gather_list=output_tensors if my_rank == dst else None, + dst=dst, + group=group, + ) + if my_rank != dst: + return + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) # type: ignore[call-overload] + tensor_size = object_size_list[i] + object_gather_list[i] = _tensor_to_object(tensor, tensor_size) + + +def _object_to_tensor(obj, device=None): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined] + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. + # Otherwise, it will casue 100X slowdown. 
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage)
+    local_size = torch.LongTensor([byte_tensor.numel()])
+    if device is not None:
+        byte_tensor = byte_tensor.to(device)
+        local_size = local_size.to(device)
+    return byte_tensor, local_size
+
+
+def _tensor_to_object(tensor, tensor_size):
+    buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
+
+
+def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
+    # src rank send to all other ranks
+    size = torch.LongTensor([0]).to(device)
+
+    if cur_rank == src:
+        world_size = dist.get_world_size(group=group)
+        tensor, size = _object_to_tensor(obj)
+        tensor = tensor.to(device)
+        size = size.to(device)
+
+        # 首先同步 obj 的 size 的信息;
+        dist.broadcast(size, src, group=group)
+        for subrank in range(world_size):
+            if subrank != src:
+                dist.send(tensor=tensor, dst=subrank, group=group, tag=tag)
+    else:
+        dist.broadcast(size, src, group=group)
+        tensor = torch.ByteTensor([0] * size).to(device)
+        dist.recv(tensor=tensor, src=src, group=group, tag=tag)
+
+    return _tensor_to_object(tensor.cpu(), size)
+
+def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) -> List:
+    """
+    实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。
+
+    example:
+        obj = {
+            'a': [1, 1],
+            'b': [[1, 2], [1, 2]],
+            'c': {
+                'd': [1, 2]
+            }
+        }
+        ->
+        [
+            {'a': 1, 'b':[1, 2], 'c':{'d': 1}},
+            {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
+        ]
+
+    :param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行
+        序列化之后进行传输。
+    :param device: 当前该参数无意义。
+    :param group:
+    :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。
+    """
+    if group is None:
+        group = DEFAULT_TORCH_GROUP
+    if isinstance(obj, torch.Tensor):
+        objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
+        dist.all_gather(objs, obj, group=group)
+    else:
+        objs = [None for _ in range(dist.get_world_size(group))]
+        # 防止 unpickle 的时候弄到发送的 gpu 上了
+        obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
+        if _TORCH_GREATER_EQUAL_1_8:
+            dist.all_gather_object(objs, obj, group=group)
+        else:
+            objs = all_gather_object(objs, obj, group=group)
+    return objs
+
+
+def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
+    """
+    将 src 上的 obj 对象广播到其它 rank 上。
+
+    :param obj:
+    :param src:
+    :param device:
+    :param group:
+    :return:
+    """
+    if group is None:
+        group = DEFAULT_TORCH_GROUP
+    cur_rank = dist.get_rank(group)
+    if cur_rank == src:
+        # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里
+        obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
+    if _TORCH_GREATER_EQUAL_1_8:
+        if cur_rank != src:
+            get_obj = [None]
+            dist.broadcast_object_list(get_obj, src=src, group=group)
+            return get_obj[0]
+        else:
+            dist.broadcast_object_list([obj], src=src, group=group)
+            return obj
+    if device is None:
+        device = torch.cuda.current_device()
+
+    if cur_rank == src:
+        tensor, size = _object_to_tensor(obj, device=device)
+    else:
+        size = torch.LongTensor([0]).to(device)
+
+    dist.broadcast(size, src=src, group=group)
+    if cur_rank != src:
+        tensor = torch.empty(
+            size.int().item(),  # type: ignore[arg-type]
+            dtype=torch.uint8,
+            device=device
+        )
+    dist.broadcast(tensor, src=src, group=group)
+
+    return _tensor_to_object(tensor, tensor_size=size.item())
+
+
+def _check_for_nccl_backend(group):
+    pg = group or dist.distributed_c10d._get_default_group()
+    # It is not expected for PG to be wrapped many times, but support it just
+    # in case
+    while isinstance(pg, _ProcessGroupWrapper):
+        pg = pg.wrapped_pg
+
+    return (
+        dist.is_nccl_available() and
+        isinstance(pg, dist.ProcessGroupNCCL)
+    )
+
+
+def all_gather_object(object_list, obj, group=None):
+    """
+    复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。
+
+    Gathers picklable objects from the whole group into a list. Similar to
+    :func:`all_gather`, but Python objects can be passed in. Note that the object
+    must be picklable in order to be gathered.
+
+    Args:
+        object_list (list[Any]): Output list. It should be correctly sized as the
+            size of the group for this collective and will contain the output.
+        object (Any): Picklable Python object to be broadcast from current process.
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+
+    Returns:
+        None. If the calling rank is part of this group, the output of the
+        collective will be populated into the input ``object_list``. If the
+        calling rank is not part of the group, the passed in ``object_list`` will
+        be unmodified.
+
+    .. note:: Note that this API differs slightly from the :func:`all_gather`
+        collective since it does not provide an ``async_op`` handle and thus
+        will be a blocking call.
+
+    .. note:: For NCCL-based process groups, internal tensor representations
+        of objects must be moved to the GPU device before communication takes
+        place. In this case, the device used is given by
+        ``torch.cuda.current_device()`` and it is the user's responsibility to
+        ensure that this is set so that each rank has an individual GPU, via
+        ``torch.cuda.set_device()``.
+
+    .. warning::
+        :func:`all_gather_object` uses ``pickle`` module implicitly, which is
+        known to be insecure. It is possible to construct malicious pickle data
+        which will execute arbitrary code during unpickling. Only call this
+        function with data you trust.
+
+    Example::
+        >>> # Note: Process group initialization omitted on each rank.
+        >>> import torch.distributed as dist
+        >>> # Assumes world_size of 3.
+        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
+        >>> output = [None for _ in gather_objects]
+        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
+        >>> output
+        ['foo', 12, {1: 2}]
+    """
+    if dist.distributed_c10d._rank_not_in_group(group):
+        return
+    if _TORCH_GREATER_EQUAL_1_8:
+        current_device = torch.device("cpu")
+        is_nccl_backend = _check_for_nccl_backend(group)
+        if is_nccl_backend:
+            # See note about using torch.cuda.current_device() here in docstring.
+            # We cannot simply use my_rank since rank == device is not necessarily
+            # true.
+            current_device = torch.device("cuda", torch.cuda.current_device())
+    else:
+        current_device = torch.cuda.current_device()
+
+    input_tensor, local_size = _object_to_tensor(obj, device=current_device)
+
+    # Gather all local sizes. This is so that we can find the max size, and index
+    # until the correct size when deserializing the tensors.
+    group_size = dist.get_world_size(group=group)
+    object_sizes_tensor = torch.zeros(
+        group_size, dtype=torch.long, device=current_device
+    )
+    object_size_list = [
+        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
+    ]
+    # Allgather tensor sizes
+    dist.all_gather(object_size_list, local_size, group=group)
+    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
+    # Resize tensor to max size across all ranks.
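+    # Every rank must pass same-shaped tensors to the collective, so each one
+    # pads its pickled payload up to the shared maximum; the true byte counts
+    # gathered above are used to slice the results back down before unpickling.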
+ input_tensor.resize_(max_object_size) + coalesced_output_tensor = torch.empty( + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + dist.all_gather(output_tensors, input_tensor, group=group) + # Deserialize outputs back to object. + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + if tensor.device != torch.device("cpu"): + tensor = tensor.cpu() + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size) + return object_list \ No newline at end of file diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 1b29fd07..a083e42c 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -1,13 +1,12 @@ import os +import shutil from functools import partial -from typing import List, Union, Optional, Dict +from typing import List, Union, Optional, Dict, Tuple, Callable from .paddle_driver import PaddleDriver from .fleet_launcher import FleetLauncher from .utils import ( _FleetWrappingModel, - ForwardState, - _MODE_PARAMETER, get_device_from_visible, reset_seed, replace_sampler, @@ -47,8 +46,7 @@ if _NEED_IMPORT_PADDLE: __all__ = [ "PaddleFleetDriver", ] -# if os.path.exists(self.gloo_rendezvous_dir): -# shutil.rmtree(self.gloo_rendezvous_dir) + class PaddleFleetDriver(PaddleDriver): def __init__( self, @@ -104,34 +102,6 @@ class PaddleFleetDriver(PaddleDriver): # 我们就直接将 model_device 置为 None; self._model_device = None - def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): - if isinstance(batch, Dict) and not wo_auto_param_call: - return auto_param_call(step_fn, batch, signature_fn=signature_fn) - else: - return self._validate_step(batch) - - model = model._layers - if hasattr(model, "train_step"): - logger.warning( - "Notice your model is a `paddle.DataParallel` model. And your " - "model also implements the `train_step` method, which we can not call actually, we will" - " call `forward` function instead of `train_step` and you should note that.") - self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) - - if hasattr(model, "evaluate_step"): - logger.warning( - "Notice your model is a `paddle.DataParallel` model. And your " - "model also implements the `evaluate_step` method, which we can not call actually, " - "we will call `forward` function instead of `evaluate_step` and you should note that.") - self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) - - if hasattr(model, "test_step"): - logger.warning( - "Notice your model is a `paddle.DataParallel` model. 
And your " - "model also implements the `test_step` method, which we can not call actually, we will" - " call `forward` function instead of `test_step` and you should note that.") - self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) - # 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; self._data_device = kwargs.get("data_device", None) if self._data_device is not None: @@ -150,8 +120,6 @@ class PaddleFleetDriver(PaddleDriver): self.world_size = None self.global_rank = 0 - self._configured = False # 防止重复调用 configure_ddp() 函数使用 - self._has_setup = False # 防止重复调用 setup() 函数 self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {}) check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) @@ -173,6 +141,9 @@ class PaddleFleetDriver(PaddleDriver): os.makedirs(name=self.output_from_new_proc, exist_ok=True) self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) + self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; + self._has_fleetwrapped = False # 判断传入的模型是否经过 _has_fleetwrapped 包裹; + def setup(self): """ 在主进程拉起其它子进程,将主进程作为rank 0 @@ -268,17 +239,17 @@ class PaddleFleetDriver(PaddleDriver): dist.barrier() def configure_fleet(self): - if not self._configured and not isinstance(self.model, DataParallel): + if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): self.model = DataParallel( _FleetWrappingModel(self.model), **self._fleet_kwargs ) + self._has_fleetwrapped = True - self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) - self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) - self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) - - self._configured = True + def on_exception(self): + if os.path.exists(self.gloo_rendezvous_dir): + shutil.rmtree(self.gloo_rendezvous_dir) + super().on_exception() @property def world_size(self) -> int: @@ -310,14 +281,39 @@ class PaddleFleetDriver(PaddleDriver): return self._data_device return self.model_device - def train_step(self, batch): - return self._train_step(batch) + def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: + if self._has_fleetwrapped: + return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn, + wo_auto_param_call=self.wo_auto_param_call) + else: + if isinstance(batch, Dict) and not self.wo_auto_param_call: + return auto_param_call(fn, batch, signature_fn=signature_fn) + else: + return fn(batch) - def validate_step(self, batch): - return self._validate_step(batch) + def get_model_call_fn(self, fn: str) -> Tuple: + model = self.unwrap_model() + if self._has_fleetwrapped: + if hasattr(model, fn): + fn = getattr(model, fn) + if not callable(fn): + raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.") + return fn, None + elif fn in {"train_step", "evaluate_step"}: + return model, model.forward + else: + raise RuntimeError(f"There is no `{fn}` method in your model.") + else: + if hasattr(model, fn): + logger.warning("Notice your model is a `DistributedDataParallel` model. 
And your model also implements " + f"the `{fn}` method, which we can not call actually, we will" + " call `forward` function instead of `train_step` and you should note that.") + elif fn not in {"train_step", "evaluate_step"}: + raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a " + "`DistributedDataParallel` model, which means that we will only call model.forward " + "function when we are in forward propagation.") - def test_step(self, batch): - return self._test_step(batch) + return self.model, model.forward def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], reproducible: bool = False, sampler_or_batch_sampler=None): @@ -406,14 +402,6 @@ class PaddleFleetDriver(PaddleDriver): else: raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") - def backward(self, loss): - self.grad_scaler.scale(loss).backward() - - def step(self): - for optimizer in self.optimizers: - self.grad_scaler.step(optimizer) - self.grad_scaler.update() - def is_global_zero(self): return self.global_rank == 0 @@ -450,3 +438,45 @@ class PaddleFleetDriver(PaddleDriver): if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)): raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " f"not {type(each_optimizer)}.") + + def broadcast_object(self, obj, src:int=0, group=None, **kwargs): + """ + 从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 + 传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 + + :param obj: obj,可能是 Tensor 或 嵌套类型的数据 + :param int src: source 的 global rank 。 + :param int dst: target 的 global rank,可以是多个目标 rank + :param group: 所属的 group + :param kwargs: + :return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 + 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 + """ + return + return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group) + + def all_gather(self, obj, group) -> List: + """ + 将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 + pickle 进行序列化,接收到之后再反序列化。 + + example: + obj = { + 'a': [1, 1], + 'b': [[1, 2], [1, 2]], + 'c': { + 'd': [1, 2] + } + } + -> + [ + {'a': 1, 'b':[1, 2], 'c':{'d': 1}}, + {'a': 1, 'b':[1, 2], 'c':{'d': 2}} + ] + + :param obj: 需要传输的对象,在每个rank上都应该保持相同的结构。 + :param group: + :return: + """ + return + return fastnlp_paddle_all_gather(obj, group=group) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 977eaf2c..37a5e59e 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -71,6 +71,14 @@ class PaddleDriver(Driver): for optimizer in self.optimizers: optimizer.clear_grad() + def backward(self, loss): + self.grad_scaler.scale(loss).backward() + + def step(self): + for optimizer in self.optimizers: + self.grad_scaler.step(optimizer) + self.grad_scaler.update() + @staticmethod def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): r""" @@ -115,28 +123,6 @@ class PaddleDriver(Driver): raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " f"not {type(each_optimizer)}.") - def check_evaluator_mode(self, mode: str): - r""" - 因为我们在具体的 driver 的 evaluate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; - 因此如果用户的 evaluator 
evaluate_fn 是 validate,但是传入的 model 却没有实现 evaluate_step 函数,而是实现了 test_step 函数,那么 - 我们应当提醒用户这一行为; - """ - model = self.unwrap_model() - if mode == "validate": - if not hasattr(model, "evaluate_step"): - if hasattr(model, "test_step"): - logger.warning( - "Your model does not have 'evaluate_step' method but has 'test_step' method, but you" - "are using 'Evaluator.validate', we are going to use 'test_step' to substitute for" - "'evaluate_step'.") - - else: - if not hasattr(model, "test_step"): - if hasattr(model, "evaluate_step"): - logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" - "are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for" - "'test_step'.") - @staticmethod def tensor_to_numeric(tensor, reduce=None): r""" @@ -268,10 +254,10 @@ class PaddleDriver(Driver): except: # 有可能 batch_size 为 None,就只有损失精度了 pass assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report." + states["sampler_states"] = sampler_states else: raise RuntimeError( "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") - states["sampler_states"] = sampler_states # 2. 保存模型的状态; if should_save_model: diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index f11cb49a..e47360ee 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Dict, Union +from typing import Optional, Dict, Union, Callable, Tuple from .paddle_driver import PaddleDriver from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible @@ -11,16 +11,19 @@ from fastNLP.core.utils import ( get_paddle_device_id, paddle_move_data_to_device, ) +from fastNLP.core.utils.utils import _get_fun_msg from fastNLP.core.samplers import ( ReproducibleBatchSampler, RandomBatchSampler, ReproducibleSampler, + RandomSampler, re_instantiate_sampler, ) from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: import paddle + from paddle import DataParallel from paddle.fluid.reader import _DatasetKind __all__ = [ @@ -28,109 +31,57 @@ __all__ = [ ] class PaddleSingleDriver(PaddleDriver): - def __init__(self, model, device: str, fp16: Optional[bool] = False, **kwargs): + def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs): + if isinstance(model, DataParallel): + raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") + + cuda_visible_devices = os.environ.get(USER_CUDA_VISIBLE_DEVICES, None) + if cuda_visible_devices == "": + device = "cpu" + logger.info("You have set `CUDA_VISIBLE_DEVICES` to '' in system environment variable, and we are gonna to" + "use `cpu` instead of `gpu` device.") + super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) if device is None: raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") + if device != "cpu": + if isinstance(device, int): + device_id = device + else: + device_id = get_paddle_device_id(device) + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] self.model_device = get_paddle_gpu_str(device) self.local_rank = 0 self.global_rank = 0 self.world_size = 1 - if isinstance(model, paddle.DataParallel): - # 注意这里的 unwrap_model 调用的是具体子类的方法; - model = self.unwrap_model() - if hasattr(model, "train_step"): - logger.warning("Notice your model is a 
`paddle.DataParallel` model. And your model also " - "implements the `train_step` method, which we can not call actually, we will " - " call `forward` function instead of `train_step` and you should note that.") - self._train_step = self.model - self._train_signature_fn = model.forward - - if hasattr(model, "evaluate_step"): - logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " - "implements the `evaluate_step` method, which we can not call actually, we " - "will call `forward` function instead of `evaluate_step` and you should note that.") - self._validate_step = self.model - self._validate_signature_fn = model.forward - - if hasattr(model, "test_step"): - logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " - "implements the `test_step` method, which we can not call actually, we will " - "call `forward` function instead of `test_step` and you should note that.") - self._test_step = self.model - self._test_signature_fn = model.forward - else: - if hasattr(self.model, "train_step"): - self._train_step = self.model.train_step - self._train_signature_fn = None - else: - self._train_step = self.model - # 输入的模型是 `DataParallel`,我们需要保证其 signature_fn 是正确的; - model = self.unwrap_model() - self._train_signature_fn = model.forward - - if hasattr(self.model, "evaluate_step"): - self._validate_step = self.model.evaluate_step - self._validate_signature_fn = None - elif hasattr(self.model, "test_step"): - self._validate_step = self.model.test_step - self._validate_signature_fn = self.model.test_step - else: - self._validate_step = self.model - model = self.unwrap_model() - self._validate_signature_fn = model.forward - - if hasattr(self.model, "test_step"): - self._test_step = self.model.test_step - self._test_signature_fn = None - elif hasattr(self.model, "evaluate_step"): - self._test_step = self.model.evaluate_step - self._test_signature_fn = self.model.evaluate_step - else: - self._test_step = self.model - model = self.unwrap_model() - self._test_signature_fn = model.forward - def setup(self): device = self.model_device - if device != "cpu": - device_id = get_paddle_device_id(device) - device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] - os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) - device = get_device_from_visible(device, output_type=str) + device = get_device_from_visible(device, output_type=str) paddle.device.set_device(device) self.model.to(device) - def train_step(self, batch) -> Dict: - # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; + def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: if isinstance(batch, Dict) and not self.wo_auto_param_call: - return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) + return auto_param_call(fn, batch, signature_fn=signature_fn) else: - return self._train_step(batch) + return fn(batch) - def backward(self, loss): - self.grad_scaler.scale(loss).backward() - - def step(self): - for optimizer in self.optimizers: - self.grad_scaler.step(optimizer) - self.grad_scaler.update() - - def validate_step(self, batch) -> Dict: - if isinstance(batch, Dict) and not self.wo_auto_param_call: - return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) + def get_model_call_fn(self, fn: str) -> Tuple: + if hasattr(self.model, fn): + fn = getattr(self.model, fn) + if not callable(fn): + raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") + 
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') + return fn, None + elif fn in {"train_step", "evaluate_step"}: + logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...') + return self.model, self.model.forward else: - return self._validate_step(batch) - - def test_step(self, batch) -> Dict: - if isinstance(batch, Dict) and not self.wo_auto_param_call: - return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) - else: - return self._test_step(batch) + raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") def move_data_to_device(self, batch: 'paddle.Tensor'): r""" @@ -164,12 +115,18 @@ class PaddleSingleDriver(PaddleDriver): return replace_sampler(dataloader, sampler) if reproducible: - batch_sampler = RandomBatchSampler( - batch_sampler=args.batch_sampler, - batch_size=args.batch_size, - drop_last=args.drop_last - ) - return replace_batch_sampler(dataloader, batch_sampler) + if isinstance(args.sampler, paddle.io.RandomSampler): + # 如果本来就是随机的,直接替换 + sampler = RandomSampler(args.sampler.data_source) + logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") + return replace_sampler(dataloader, sampler) + else: + batch_sampler = RandomBatchSampler( + batch_sampler=args.batch_sampler, + batch_size=args.batch_size, + drop_last=args.drop_last + ) + return replace_batch_sampler(dataloader, batch_sampler) else: return dataloader diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index 2f74cc65..feb5c3eb 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -11,7 +11,6 @@ from typing import Dict, Optional, Union from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to -from fastNLP.core.samplers import RandomSampler from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES from fastNLP.core.log import logger @@ -87,8 +86,6 @@ class ForwardState(IntEnum): TEST = 2 PREDICT = 3 -_MODE_PARAMETER = "forward_state" - class _FleetWrappingModel(Layer): """ 参考_DDPWrappingModel,paddle的分布式训练也需要用paddle.nn.DataParallel进行包装,采用和 @@ -98,83 +95,16 @@ class _FleetWrappingModel(Layer): super(_FleetWrappingModel, self).__init__() self.model = model - if isinstance(model, paddle.DataParallel): - model = model._layers - if hasattr(model, "train_step"): - logger.warning( - "Notice your model is a `paddle.DataParallel` model. And your " - "model also implements the `train_step` method, which we can not call actually, we will" - " call `forward` function instead of `train_step` and you should note that.") - self._train_step = self.model - self._train_signature_fn = model.forward - - if hasattr(model, "evaluate_step"): - logger.warning( - "Notice your model is a `paddle.DataParallel` model. And your " - "model also implements the `evaluate_step` method, which we can not call actually, " - "we will call `forward` function instead of `evaluate_step` and you should note that.") - self._validate_step = self.model - self._validate_signature_fn = model.forward - - if hasattr(model, "test_step"): - logger.warning( - "Notice your model is a `paddle.DataParallel` model. 
And your " - "model also implements the `test_step` method, which we can not call actually, we will" - " call `forward` function instead of `test_step` and you should note that.") - self._test_step = self.model - self._test_signature_fn = model.forward - else: - if hasattr(model, "train_step"): - self._train_step = model.train_step - self._train_signature_fn = None - else: - self._train_step = model - self._train_signature_fn = model.forward - - if hasattr(model, "evaluate_step"): - self._validate_step = model.validate_step - self._validate_signature_fn = None - elif hasattr(model, "test_step"): - self._validate_step = model.test_step - self._validate_signature_fn = None - else: - self._validate_step = model - self._validate_signature_fn = model.forward - - if hasattr(model, "test_step"): - self._test_step = model.test_step - self._test_signature_fn = None - elif hasattr(model, "evaluate_step"): - self._test_step = model.validate_step - self._test_signature_fn = None - else: - self._test_step = model - self._test_signature_fn = model.forward - def forward(self, batch, **kwargs) -> Dict: - forward_state = kwargs.pop(_MODE_PARAMETER) + fn = kwargs.pop("fastnlp_fn") + signature_fn = kwargs.pop("fastnlp_signature_fn") wo_auto_param_call = kwargs.pop("wo_auto_param_call") - if forward_state == ForwardState.TRAIN: - if isinstance(batch, Dict) and not wo_auto_param_call: - return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) - else: - return self._train_step(batch) - elif forward_state == ForwardState.VALIDATE: - if isinstance(batch, Dict) and not wo_auto_param_call: - return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) - else: - return self._validate_step(batch) - elif forward_state == ForwardState.TEST: - if isinstance(batch, Dict) and not wo_auto_param_call: - return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) - else: - return self._test_step(batch) - elif forward_state == ForwardState.PREDICT: - raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.") + if isinstance(batch, Dict) and not wo_auto_param_call: + return auto_param_call(fn, batch, signature_fn=signature_fn) else: - raise NotImplementedError("You should direct a concrete evaluate_fn.") + return fn(batch) class DummyGradScaler: """ diff --git a/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py b/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py index 2f4526ac..20be8a37 100644 --- a/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py +++ b/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py @@ -1,6 +1,7 @@ -from typing import Optional, Dict, Union, Callable +from typing import Optional, Dict, Union, Callable, Tuple from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH +from fastNLP.core.utils.utils import _get_fun_msg if _NEED_IMPORT_PADDLE: @@ -48,33 +49,6 @@ class TorchPaddleDriver(Driver): elif self._data_device is not None: raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") - if hasattr(self.model, "train_step"): - self._train_step = self.model.train_step - self._train_signature_fn = None - else: - self._train_step = self.model - self._train_signature_fn = self.model.forward - - if hasattr(self.model, "evaluate_step"): - self._validate_step = self.model.evaluate_step - self._validate_signature_fn = None - elif hasattr(self.model, "test_step"): - self._validate_step = 
self.model.test_step - self._validate_signature_fn = self.model.forward - else: - self._validate_step = self.model - self._validate_signature_fn = self.model.forward - - if hasattr(self.model, "test_step"): - self._test_step = self.model.test_step - self._test_signature_fn = None - elif hasattr(self.model, "evaluate_step"): - self._test_step = self.model.evaluate_step - self._test_signature_fn = self.model.forward - else: - self._test_step = self.model - self._test_signature_fn = self.model.forward - def setup(self): if self.model_device is not None: paddle.device.set_device(self.model_device.replace("cuda", "gpu")) @@ -103,12 +77,6 @@ class TorchPaddleDriver(Driver): f"'torch.optim.Optimizer' or 'paddle.optimizers.Optimizer' type, " f"not {type(each_optimizer)}.") - def train_step(self, batch) -> Dict: - if isinstance(batch, Dict): - return auto_param_call(self._train_step, batch) - else: - return self._train_step(batch) - def step(self): for optimizer in self.optimizers: optimizer.step() @@ -125,17 +93,24 @@ class TorchPaddleDriver(Driver): else: raise ValueError("Unknown optimizers type.") - def validate_step(self, batch): - if isinstance(batch, Dict): - return auto_param_call(self._validate_step, batch) + def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: + if isinstance(batch, Dict) and not self.wo_auto_param_call: + return auto_param_call(fn, batch, signature_fn=signature_fn) else: - return self._validate_step(batch) + return fn(batch) - def test_step(self, batch): - if isinstance(batch, Dict): - return auto_param_call(self._test_step, batch) + def get_model_call_fn(self, fn: str) -> Tuple: + if hasattr(self.model, fn): + fn = getattr(self.model, fn) + if not callable(fn): + raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") + logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') + return fn, None + elif fn in {"train_step", "evaluate_step"}: + logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...') + return self.model, self.model.forward else: - return self._test_step(batch) + raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") def predict_step(self, batch): if isinstance(batch, Dict): diff --git a/fastNLP/modules/mix_modules/mix_module.py b/fastNLP/modules/mix_modules/mix_module.py index 2ee26133..1c2bd9e1 100644 --- a/fastNLP/modules/mix_modules/mix_module.py +++ b/fastNLP/modules/mix_modules/mix_module.py @@ -85,7 +85,7 @@ class MixModule: def test_step(self, batch): raise NotImplementedError - def validate_step(self, batch): + def evaluate_step(self, batch): raise NotImplementedError def train(self): diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py index 69b16427..8a3ab2ce 100644 --- a/tests/core/controllers/test_trainer_paddle.py +++ b/tests/core/controllers/test_trainer_paddle.py @@ -1,13 +1,11 @@ import pytest import os os.environ["FASTNLP_BACKEND"] = "paddle" -from typing import Any from dataclasses import dataclass from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.callbacks.progress_callback import RichCallback -from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK from paddle.optimizer import Adam from paddle.io import DataLoader @@ -19,40 +17,18 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordM from tests.helpers.utils import magic_argv_env_context @dataclass -class MNISTTrainPaddleConfig: +class TrainPaddleConfig: 
num_labels: int = 10 - feature_dimension: int = 784 + feature_dimension: int = 10 - batch_size: int = 32 + batch_size: int = 2 shuffle: bool = True - validate_every = -5 + evaluate_every = 2 - driver: str = "paddle" - device = "gpu" - -@dataclass -class MNISTTrainFleetConfig: - num_labels: int = 10 - feature_dimension: int = 784 - - batch_size: int = 32 - shuffle: bool = True - validate_every = -5 - -@dataclass -class TrainerParameters: - model: Any = None - optimizers: Any = None - train_dataloader: Any = None - validate_dataloaders: Any = None - input_mapping: Any = None - output_mapping: Any = None - metrics: Any = None - -@pytest.mark.parametrize("driver,device", [("paddle", "cpu")("paddle", 1)]) +@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)]) # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) -@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True), - RichCallback(5), RecordLossCallback(loss_threshold=0.3)]]) +@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), + RichCallback(5)]]) @magic_argv_env_context def test_trainer_paddle( driver, @@ -60,38 +36,36 @@ def test_trainer_paddle( callbacks, n_epochs=2, ): - trainer_params = TrainerParameters() - - trainer_params.model = PaddleNormalModel_Classification_1( - num_labels=MNISTTrainPaddleConfig.num_labels, - feature_dimension=MNISTTrainPaddleConfig.feature_dimension + model = PaddleNormalModel_Classification_1( + num_labels=TrainPaddleConfig.num_labels, + feature_dimension=TrainPaddleConfig.feature_dimension ) - trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) + optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001) train_dataloader = DataLoader( - dataset=PaddleRandomMaxDataset(6400, 10), - batch_size=MNISTTrainPaddleConfig.batch_size, + dataset=PaddleRandomMaxDataset(20, 10), + batch_size=TrainPaddleConfig.batch_size, shuffle=True ) val_dataloader = DataLoader( - dataset=PaddleRandomMaxDataset(1000, 10), - batch_size=MNISTTrainPaddleConfig.batch_size, + dataset=PaddleRandomMaxDataset(20, 10), + batch_size=TrainPaddleConfig.batch_size, shuffle=True ) - trainer_params.train_dataloader = train_dataloader - trainer_params.validate_dataloaders = val_dataloader - trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every - trainer_params.metrics = {"acc": Accuracy(backend="paddle")} + train_dataloader = train_dataloader + evaluate_dataloaders = val_dataloader + evaluate_every = TrainPaddleConfig.evaluate_every + metrics = {"acc": Accuracy(backend="paddle")} trainer = Trainer( - model=trainer_params.model, + model=model, driver=driver, device=device, - optimizers=trainer_params.optimizers, - train_dataloader=trainer_params.train_dataloader, - validate_dataloaders=trainer_params.validate_dataloaders, - validate_every=trainer_params.validate_every, - input_mapping=trainer_params.input_mapping, - output_mapping=trainer_params.output_mapping, - metrics=trainer_params.metrics, + optimizers=optimizers, + train_dataloader=train_dataloader, + evaluate_dataloaders=evaluate_dataloaders, + evaluate_every=evaluate_every, + input_mapping=None, + output_mapping=None, + metrics=metrics, n_epochs=n_epochs, callbacks=callbacks, diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 9661c015..fd947c73 100644 --- 
a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -56,34 +56,57 @@ def test_save_and_load_with_randombatchsampler(only_state_dict): dataset=dataset, batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) ) + num_consumed_batches = 2 # TODO 断点重训完善后在这里迭代几次 + already_seen_set = set() + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_set.update(batch) sampler_states = dataloader.batch_sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} if only_state_dict: - driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True) + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) else: - driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) - states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) + + # 加载 + # 更改 batch_size + dataloader = DataLoader( + dataset=dataset, + batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False) + ) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") # 1. 检查 optimizer 的状态 # TODO optimizer 的 state_dict 总是为空 # 2. 检查 batch_sampler 是否被正确地加载和替换 - replaced_loader = states["dataloader"] assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"] # 3. 检查 model 的参数是否被正确加载 for batch in dataloader: - res1 = driver1.validate_step(batch) - res2 = driver2.validate_step(batch) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) assert paddle.equal_all(res1["pred"], res2["pred"]) # 4. 
检查 batch_idx
-        # TODO
+        start_batch = load_states.pop('batch_idx_in_epoch')
+        assert start_batch == 2 * num_consumed_batches
+        left_batches = set()
+        for idx, batch in enumerate(replaced_loader):
+            left_batches.update(batch)
+
+        assert len(left_batches) + len(already_seen_set) == len(dataset)
+        assert len(left_batches | already_seen_set) == len(dataset)
+
+
     finally:
         synchronize_safe_rm(path)

@@ -104,21 +127,36 @@
             dataset,
             batch_sampler=batch_sampler
         )
+        num_consumed_batches = 2 # TODO 断点重训完善后在这里迭代几次
+        already_seen_set = set()
+        for idx, batch in enumerate(dataloader):
+            if idx >= num_consumed_batches:
+                break
+            already_seen_set.update(batch)
         sampler_states = dataloader.batch_sampler.sampler.state_dict()
+        save_states = {"num_consumed_batches": num_consumed_batches}
         if only_state_dict:
-            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True)
+            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
         else:
-            driver1.save(Path(path), {}, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
-        states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
+
+        # 加载
+        # 更改 batch_size
+        dataloader = DataLoader(
+            dataset=dataset,
+            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
+        )
+        load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+        replaced_loader = load_states.pop("dataloader")

         # 1. 检查 optimizer 的状态
         # TODO optimizer 的 state_dict 总是为空

         # 2. 检查 sampler 是否被正确地加载和替换
-        replaced_loader = states["dataloader"]

         assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
         assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
@@ -129,60 +167,51 @@

         # 3. 检查 model 的参数是否被正确加载
         for batch in dataloader:
-            res1 = driver1.validate_step(batch)
-            res2 = driver2.validate_step(batch)
+            res1 = driver1.model.evaluate_step(**batch)
+            res2 = driver2.model.evaluate_step(**batch)

             assert paddle.equal_all(res1["pred"], res2["pred"])

         # 4. 
检查 batch_idx - # TODO + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_batches = set() + for idx, batch in enumerate(replaced_loader): + left_batches.update(batch) + + assert len(left_batches) + len(already_seen_set) == len(dataset) + assert len(left_batches | already_seen_set) == len(dataset) finally: synchronize_safe_rm(path) -def test_save_and_load_state_dict(prepare_test_save_load): +@pytest.mark.parametrize("only_state_dict", ([True, False])) +def test_save_and_load_model(prepare_test_save_load, only_state_dict): """ - 测试save和load函数 - TODO optimizer的state_dict为空,暂时不测试 - """ - try: - path = "dict" - driver1, driver2, dataloader = prepare_test_save_load - - driver1.save_model(path) - driver2.load_model(path) - - for batch in dataloader: - batch = driver1.move_data_to_device(batch) - res1 = driver1.validate_step(batch) - res2 = driver2.validate_step(batch) - - assert paddle.equal_all(res1["pred"], res2["pred"]) - finally: - synchronize_safe_rm(path) - -def test_save_and_load_whole_model(prepare_test_save_load): - """ - 测试save和load函数 - TODO optimizer的state_dict为空,暂时不测试 + 测试 save_model 和 load_model 函数 """ try: path = "model" driver1, driver2, dataloader = prepare_test_save_load - driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))]) - driver2.load_model(path, only_state_dict=False) + if only_state_dict: + driver1.save_model(path, only_state_dict) + else: + driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((32, 10))]) + driver2.load_model(path, only_state_dict) for batch in dataloader: batch = driver1.move_data_to_device(batch) - res1 = driver1.validate_step(batch) - res2 = driver2.validate_step(batch) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) assert paddle.equal_all(res1["pred"], res2["pred"]) finally: - synchronize_safe_rm(path + ".pdiparams") - synchronize_safe_rm(path + ".pdiparams.info") - synchronize_safe_rm(path + ".pdmodel") - + if only_state_dict: + synchronize_safe_rm(path) + else: + synchronize_safe_rm(path + ".pdiparams") + synchronize_safe_rm(path + ".pdiparams.info") + synchronize_safe_rm(path + ".pdmodel") class TestSingleDeviceFunction: """ @@ -199,13 +228,7 @@ class TestSingleDeviceFunction: 测试能否运行 """ res = self.driver.unwrap_model() - - def test_check_evaluator_mode(self): - """ - 这两个函数没有返回值和抛出异常,仅检查是否有import错误等影响运行的因素 - """ - self.driver.check_evaluator_mode("validate") - self.driver.check_evaluator_mode("test") + assert res is self.driver.model def test_is_distributed(self): assert self.driver.is_distributed() == False @@ -237,21 +260,30 @@ class TestSetDistReproDataloder: assert replaced_loader is dataloader - def test_set_dist_repro_dataloader_with_reproducible_true(self): + @pytest.mark.parametrize("shuffle", [True, False]) + def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 - 当dist为字符串时,此时应该返回新的 dataloader,且 batch_sampler 为 RandomBatchSampler + 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True), + 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler """ - dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) assert not (replaced_loader is dataloader) - assert 
isinstance(replaced_loader.batch_sampler, RandomBatchSampler) - assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) + if shuffle: + # 此时会替换 sampler + assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + else: + # 此时会替换 batch_sampler + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last - # self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader) def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): """ diff --git a/tests/helpers/callbacks/helper_callbacks.py b/tests/helpers/callbacks/helper_callbacks.py index 751d59f2..c3a9d4da 100644 --- a/tests/helpers/callbacks/helper_callbacks.py +++ b/tests/helpers/callbacks/helper_callbacks.py @@ -72,7 +72,7 @@ class RecordTrainerEventTriggerCallback(Callback): print("on_train_end") def on_train_epoch_begin(self, trainer): - if trainer.current_epoch_idx >= 1: + if trainer.cur_epoch_idx >= 1: # 触发 on_exception; raise Exception print("on_train_epoch_begin") diff --git a/tests/helpers/models/paddle_model.py b/tests/helpers/models/paddle_model.py index a830b1ff..efa8c0ce 100644 --- a/tests/helpers/models/paddle_model.py +++ b/tests/helpers/models/paddle_model.py @@ -26,7 +26,7 @@ class PaddleNormalModel_Classification_1(paddle.nn.Layer): x = self(x) return {"loss": self.loss_fn(x, y)} - def validate_step(self, x, y): + def evaluate_step(self, x, y): x = self(x) return {"pred": x, "target": y.reshape((-1,))} From 288eb36afbae4d3583e662fbc50f7eb9920b128f Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 15 Apr 2022 09:05:44 +0000 Subject: [PATCH 5/8] =?UTF-8?q?=E6=96=AD=E7=82=B9=E9=87=8D=E8=AE=AD=20save?= =?UTF-8?q?=E6=97=B6=E7=9A=84=E9=80=BB=E8=BE=91=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/paddle_driver/paddle_driver.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 37a5e59e..3b8ad7d8 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -244,17 +244,18 @@ class PaddleDriver(Driver): if hasattr(sampler, "state_dict") and callable(sampler.state_dict): sampler_states = sampler.state_dict() # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples - # 会造成多余实际消耗的问题。 - num_consumed_samples_array = sampler_states.pop("num_consumed_samples_array", None) + # 会造成多余实际消耗的问题。 + num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) if num_consumed_samples_array is not None: - sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] - else: - try: - sampler_states["num_consumed_samples"] = num_consumed_batches * dataloader_args.batch_size - except: # 有可能 batch_size 为 None,就只有损失精度了 - pass - assert sampler_states["num_consumed_samples"] != -1, "This is a bug, please report." 
- states["sampler_states"] = sampler_states + if isinstance(sampler, ReproducibleSampler): + # 如果是 sampler 的话,需要计算出实际的 sample 数目 + try: + num_consumed_batches = num_consumed_batches * dataloader_args.batch_size + except: # 有可能 batch_size 为 None,就只有损失精度了 + num_consumed_batches = sampler_states['num_consumed_samples'] + sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] + assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." + states['sampler_states'] = sampler_states else: raise RuntimeError( "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") From cf19062fb25f08cb527013636e918231b5d123b6 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 15 Apr 2022 09:06:22 +0000 Subject: [PATCH 6/8] =?UTF-8?q?set=5Fdist=5Frepro=5Fdataloader=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E4=BE=8B=E7=9A=84=E5=AE=8C=E5=96=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/drivers/paddle_driver/test.py | 25 +++++++ tests/core/drivers/paddle_driver/test2.py | 21 ++++++ .../core/drivers/paddle_driver/test_fleet.py | 67 ++++++++++-------- .../paddle_driver/test_single_device.py | 69 ++++++++++++------- 4 files changed, 129 insertions(+), 53 deletions(-) create mode 100644 tests/core/drivers/paddle_driver/test.py create mode 100644 tests/core/drivers/paddle_driver/test2.py diff --git a/tests/core/drivers/paddle_driver/test.py b/tests/core/drivers/paddle_driver/test.py new file mode 100644 index 00000000..5455a230 --- /dev/null +++ b/tests/core/drivers/paddle_driver/test.py @@ -0,0 +1,25 @@ +import sys +import os +import warnings +warnings.filterwarnings("ignore") +os.environ["FASTNLP_BACKEND"] = "torch" +sys.path.append("../../../../") + +import paddle +from fastNLP.core.samplers import RandomSampler +from fastNLP.core.drivers.paddle_driver.utils import replace_sampler, replace_batch_sampler +from tests.helpers.datasets.paddle_data import PaddleNormalDataset + +dataset = PaddleNormalDataset(20) +batch_sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=2) +batch_sampler.sampler = RandomSampler(dataset, True) +dataloader = paddle.io.DataLoader( + dataset, + batch_sampler=batch_sampler +) + +forward_steps = 9 +iter_dataloader = iter(dataloader) +for _ in range(forward_steps): + print(next(iter_dataloader)) +print(dataloader.batch_sampler.sampler.during_iter) diff --git a/tests/core/drivers/paddle_driver/test2.py b/tests/core/drivers/paddle_driver/test2.py new file mode 100644 index 00000000..aaa3150e --- /dev/null +++ b/tests/core/drivers/paddle_driver/test2.py @@ -0,0 +1,21 @@ +import torch +# from torch.utils.data import DataLoader, Dataset +import paddle +from paddle.io import Dataset, DataLoader +paddle.device.set_device("cpu") +class NormalDataset(Dataset): + def __init__(self, num_of_data=1000): + self.num_of_data = num_of_data + self._data = list(range(num_of_data)) + + def __len__(self): + return self.num_of_data + + def __getitem__(self, item): + return self._data[item] +dataset = NormalDataset(20) +dataloader = DataLoader(dataset, batch_size=2, use_buffer_reader=False) +for i, b in enumerate(dataloader): + print(b) + if i >= 2: + break diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index 434e9e5b..de98f9c5 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -117,12 +117,13 @@ class 
TestSetDistReproDataloader: """ @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 """ - dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) - batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) + batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) assert not (replaced_loader is dataloader) @@ -133,12 +134,13 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 """ - dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) - sampler = RandomSampler(self.dataset, shuffle=True) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) + sampler = RandomSampler(self.dataset, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) assert not (replaced_loader is dataloader) @@ -171,14 +173,15 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler 时的表现 """ dataloader = DataLoader( self.dataset, - batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4), + batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle), ) dataloader.batch_sampler.set_distributed( num_replicas=self.driver.world_size, @@ -195,12 +198,13 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 """ batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) - batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler.sampler = RandomSampler(self.dataset, shuffle) batch_sampler.sampler.set_distributed( num_replicas=self.driver.world_size, rank=self.driver.global_rank @@ -222,11 +226,12 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 
为一般情况时的表现 """ - dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) assert replaced_loader is dataloader @@ -238,14 +243,15 @@ class TestSetDistReproDataloader: """ @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler 的表现 """ dataloader = DataLoader( dataset=self.dataset, - batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) + batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) ) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) @@ -258,13 +264,14 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 的表现 """ - batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) - batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle) + batch_sampler.sampler = RandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, batch_sampler=batch_sampler @@ -276,16 +283,17 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) assert replaced_loader.batch_sampler.batch_size == 2 - assert replaced_loader.batch_sampler.sampler.shuffle == True + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 """ - dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) assert not (replaced_loader is dataloader) @@ -293,7 +301,7 @@ class TestSetDistReproDataloader: assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size - assert replaced_loader.batch_sampler.sampler.shuffle == True + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle dist.barrier() """ @@ -302,13 +310,14 @@ class TestSetDistReproDataloader: """ @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, 
False])) + def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 的表现 """ batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) - batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler.sampler = RandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, batch_sampler=batch_sampler @@ -320,18 +329,19 @@ class TestSetDistReproDataloader: assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) assert replaced_loader.batch_sampler.batch_size == 2 - assert replaced_loader.batch_sampler.sampler.shuffle == True + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler 的表现 """ batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) - batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True) + batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, batch_sampler=batch_sampler @@ -349,11 +359,12 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 """ - dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) assert not (replaced_loader is dataloader) diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index fd947c73..ebd4721b 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -1,4 +1,5 @@ import os +from re import S os.environ["FASTNLP_BACKEND"] = "paddle" import pytest from pathlib import Path @@ -283,30 +284,32 @@ class TestSetDistReproDataloder: assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last - self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) - def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler """ - dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) - dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) + 
diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py
index fd947c73..ebd4721b 100644
--- a/tests/core/drivers/paddle_driver/test_single_device.py
+++ b/tests/core/drivers/paddle_driver/test_single_device.py
@@ -283,30 +284,32 @@ class TestSetDistReproDataloder:
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         assert replaced_loader.drop_last == dataloader.drop_last

-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

-    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
         """
         Test set_dist_repro_dataloader when the dist argument is not a string but a ReproducibleBatchSampler;
         it should return a new dataloader whose batch_sampler is replaced by the sampler passed as dist.
         """
-        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
-        dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
+        dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
         assert replaced_loader.batch_sampler is dist

-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

-    def test_set_dist_repro_dataloader_with_dist_sampler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
         """
         Test set_dist_repro_dataloader when the dist argument is not a string but a ReproducibleSampler;
         it should return a new dataloader whose batch_sampler.sampler is replaced by the sampler passed as dist.
         """
-        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
-        dist = RandomSampler(self.dataset, shuffle=True)
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
+        dist = RandomSampler(self.dataset, shuffle=shuffle)
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

         assert not (replaced_loader is dataloader)
@@ -316,16 +319,21 @@ class TestSetDistReproDataloder:
         assert replaced_loader.batch_sampler.sampler is dist
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size

-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

-    def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle):
         """
         Test set_dist_repro_dataloader when the given dataloader already supports checkpoint-resumable training;
         it should return a new dataloader with every other setting unchanged.
         """
         dataloader = DataLoader(
             dataset=self.dataset,
-            batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
+            batch_sampler=RandomBatchSampler(
+                BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
+                batch_size=4,
+                drop_last=False,
+            )
         )
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

@@ -335,15 +343,16 @@ class TestSetDistReproDataloder:
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         assert replaced_loader.drop_last == dataloader.drop_last

-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

-    def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
+    @pytest.mark.parametrize("shuffle", ([True, False]))
+    def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle):
         """
         Test set_dist_repro_dataloader when the given dataloader already supports checkpoint-resumable training;
         it should return a new dataloader with every other setting unchanged.
         """
-        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
-        batch_sampler.sampler = RandomSampler(self.dataset, True)
+        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
+        batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
         dataloader = DataLoader(
             self.dataset,
             batch_sampler=batch_sampler
         )
@@ -355,11 +364,11 @@ class TestSetDistReproDataloder:
         assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
         assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
         assert replaced_loader.batch_sampler.batch_size == 2
-        assert replaced_loader.batch_sampler.sampler.shuffle == True
+        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle

-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
+        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

-    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
+    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
         """
         Check that set_dist_repro_dataloader produces the correct result on a single device.
         """
@@ -378,9 +387,6 @@ class TestSetDistReproDataloder:
         # load num_consumed_samples_array and work out how many batches have already been consumed
         num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)

-        import time
-        time.sleep(5)
-
         # reload; the remaining content should come out, and for PaddleNormalDataset the sorted result should be a contiguous range
         left_idxes = set()
         if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
@@ -389,16 +395,29 @@ class TestSetDistReproDataloder:
                 sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
             else:
                 sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
-            replaced_loader.batch_sampler.load_state_dict(sampler_states)
+            # rebuild the dataloader
+            new_loader = DataLoader(
+                dataset=replaced_loader.dataset,
+                batch_sampler=RandomBatchSampler(
+                    BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
+                    batch_size=batch_size,
+                    drop_last=False,
+                )
+            )
+            new_loader.batch_sampler.load_state_dict(sampler_states)
         else:
             batch_size = replaced_loader.batch_sampler.batch_size
+            num_consumed_batches = num_consumed_batches * batch_size
             if num_consumed_samples_array is not None:
                 sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
             else:
                 sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
-            replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
-            replaced_loader.batch_sampler.sampler.set_epoch(0)
-        for idx, batch in enumerate(replaced_loader):
+            # reconstruct the dataloader
+            batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
+            batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
+            new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
+            new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
+        for idx, batch in enumerate(new_loader):
             left_idxes.update(batch)

         assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
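
The check above captures the resume contract these samplers must honor: consume part of an epoch, checkpoint the sampler state, rebuild the dataloader, restore the state, and verify that exactly the unseen samples remain. A minimal, framework-free sketch of the same round trip (TinySampler and its state_dict/load_state_dict are illustrative stand-ins, not the fastNLP samplers):

import random

class TinySampler:
    def __init__(self, n, shuffle=True, seed=0):
        self.n, self.shuffle, self.seed = n, shuffle, seed
        self.num_consumed = 0

    def __iter__(self):
        idxes = list(range(self.n))
        if self.shuffle:
            # the same seed reproduces the same permutation after a restart
            random.Random(self.seed).shuffle(idxes)
        for idx in idxes[self.num_consumed:]:
            self.num_consumed += 1
            yield idx

    def state_dict(self):
        return {"num_consumed": self.num_consumed}

    def load_state_dict(self, states):
        self.num_consumed = states["num_consumed"]

sampler = TinySampler(10)
it = iter(sampler)
seen = {next(it) for _ in range(4)}   # consume 4 samples, then "crash"
states = sampler.state_dict()         # checkpoint the sampler state

resumed = TinySampler(10)             # rebuild after the restart
resumed.load_state_dict(states)
left = set(resumed)                   # the restored sampler yields only the remainder
assert seen | left == set(range(10)) and not (seen & left)
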
From 16cec4bd99d55aec8cceb06890c3f1ea5506dcce Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Fri, 15 Apr 2022 09:07:35 +0000
Subject: [PATCH 7/8] Remove unnecessary test files
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/core/drivers/paddle_driver/test.py  | 25 -----------------------
 tests/core/drivers/paddle_driver/test2.py | 21 -------------------
 2 files changed, 46 deletions(-)
 delete mode 100644 tests/core/drivers/paddle_driver/test.py
 delete mode 100644 tests/core/drivers/paddle_driver/test2.py

diff --git a/tests/core/drivers/paddle_driver/test.py b/tests/core/drivers/paddle_driver/test.py
deleted file mode 100644
index 5455a230..00000000
--- a/tests/core/drivers/paddle_driver/test.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import sys
-import os
-import warnings
-warnings.filterwarnings("ignore")
-os.environ["FASTNLP_BACKEND"] = "torch"
-sys.path.append("../../../../")
-
-import paddle
-from fastNLP.core.samplers import RandomSampler
-from fastNLP.core.drivers.paddle_driver.utils import replace_sampler, replace_batch_sampler
-from tests.helpers.datasets.paddle_data import PaddleNormalDataset
-
-dataset = PaddleNormalDataset(20)
-batch_sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=2)
-batch_sampler.sampler = RandomSampler(dataset, True)
-dataloader = paddle.io.DataLoader(
-    dataset,
-    batch_sampler=batch_sampler
-)
-
-forward_steps = 9
-iter_dataloader = iter(dataloader)
-for _ in range(forward_steps):
-    print(next(iter_dataloader))
-print(dataloader.batch_sampler.sampler.during_iter)
diff --git a/tests/core/drivers/paddle_driver/test2.py b/tests/core/drivers/paddle_driver/test2.py
deleted file mode 100644
index aaa3150e..00000000
--- a/tests/core/drivers/paddle_driver/test2.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import torch
-# from torch.utils.data import DataLoader, Dataset
-import paddle
-from paddle.io import Dataset, DataLoader
-paddle.device.set_device("cpu")
-class NormalDataset(Dataset):
-    def __init__(self, num_of_data=1000):
-        self.num_of_data = num_of_data
-        self._data = list(range(num_of_data))
-
-    def __len__(self):
-        return self.num_of_data
-
-    def __getitem__(self, item):
-        return self._data[item]
-dataset = NormalDataset(20)
-dataloader = DataLoader(dataset, batch_size=2, use_buffer_reader=False)
-for i, b in enumerate(dataloader):
-    print(b)
-    if i >= 2:
-        break

From 687db6d86aff6a5bc5863790bf2a29235bc67da7 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Fri, 15 Apr 2022 19:31:46 +0800
Subject: [PATCH 8/8] 1. Save and restore the GradScaler state when torch checkpoints are written and loaded; 2. Add the Torch gradient-clipping and warmup callbacks
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callbacks/callback.py            | 140 +------------
 fastNLP/core/callbacks/checkpoint_callback.py |   7 +-
 fastNLP/core/callbacks/early_stop_callback.py |   2 +-
 .../core/callbacks/has_monitor_callback.py    | 189 ++++++++++++++++++
 .../callbacks/load_best_model_callback.py     |   7 +-
 fastNLP/core/callbacks/progress_callback.py   |   2 +-
 .../torch_driver/initialize_torch_driver.py   |   2 +-
 .../core/drivers/torch_driver/torch_driver.py |  30 ++-
 fastNLP/core/utils/utils.py                   |  38 +++-
 tests/envs/test_set_backend.py                |   2 +-
 10 files changed, 255 insertions(+), 164 deletions(-)
 create mode 100644 fastNLP/core/callbacks/has_monitor_callback.py

diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py
index b37eda63..117cb524 100644
--- a/fastNLP/core/callbacks/callback.py
+++ b/fastNLP/core/callbacks/callback.py
@@ -1,16 +1,12 @@
-from typing import Union, Callable, Dict, Optional, Any
-from abc import ABC
 __all__ = [
     'Callback',
 ]

+from typing import Union, Callable, Dict, Optional, Any
+
 from .callback_events import Events, EventsList, Filter
-from .utils import _get_monitor_value
 from fastNLP.core.callbacks.callback_events import _SingleEventState
-from fastNLP.core.log import logger
-from fastNLP.core.utils import apply_to_collection
-from fastNLP.core.utils.utils import _check_valid_parameters_number


 class Callback:
@@ -278,135 +274,3 @@ class _CallbackWrapper(Callback):
     @property
     def callback_name(self):
         return self.fn.__name__
-
-
-class CanItemDataType(ABC):
-    """
-    Duck-type check for objects whose values can be extracted via an ``item`` method.
-
-    """
-
-    @classmethod
-    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
-        if cls is CanItemDataType:
-            item = getattr(subclass, 'item', None)
-            return callable(item)
-        return NotImplemented
-
-
-class HasMonitorCallback(Callback):
-    def __init__(self, monitor, larger_better, must_have_monitor=False):
-        self.set_monitor(monitor, larger_better)
-        self.must_have_monitor = must_have_monitor
-
-    def set_monitor(self, monitor, larger_better):
-        if callable(monitor):  # check that it can accept a single argument
-            _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
-            self.monitor = monitor
-        else:
-            self.monitor = str(monitor) if monitor is not None else None
-        self.larger_better = bool(larger_better)
-        if larger_better:
-            self.monitor_value = float('-inf')
-        else:
-            self.monitor_value = float('inf')
-        self._real_monitor = self.monitor
-
-    def on_after_trainer_initialized(self, trainer, driver):
-        """
-        If no monitor is set on this callback, fall back to the monitor configured on the Trainer.
-        For callbacks that must have a monitor, this method also performs that check.
-
-        :param trainer:
-        :param driver:
-        :return:
-        """
-        if self.monitor is None and trainer.monitor is not None:
-            self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
-        if self.must_have_monitor and self.monitor is None:
-            raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
-                               f"You can set it in the initialization or through Trainer.")
-
-    def get_monitor_value(self, results:Dict)->Union[float, None]:
-        """
-        Get the monitor value. If the monitor is not found directly, fuzzy matching is tried and the
-        matched name is stored in the self._real_monitor attribute.
-
-        :param results:
-        :return: None means no suitable monitor was found this time
-        """
-        if len(results)==0:
-            return None
-        # make sure every tensor has been converted to a plain Python value
-        results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
-        use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
-                                                        real_monitor=self._real_monitor,
-                                                        res=results)
-        if monitor_value is None:
-            return monitor_value
-        # first run
-        if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
-            logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
-                           f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
-        # the monitor has changed since the last check
-        elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
-            logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
-                           f"The expected monitor is:`{self.monitor}`, last used monitor is:"
-                           f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
-                           f"customized monitor function when the evaluation results vary between validations.")
-
-        self._real_monitor = use_monitor
-        return monitor_value
-
-    def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
-        """
-        Check whether monitor_value is a better value.
-
-        :param monitor_value: the value to check; if None, False is returned
-        :param keep_if_better: if the given monitor_value is better, keep it as the new best.
-        :return:
-        """
-        if monitor_value is None:
-            return False
-        better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
-        if keep_if_better and better:
-            self.monitor_value = monitor_value
-        return better
-
-    def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
-        """
-        Whether monitor_value1 is the better of the two given values.
-
-        :param monitor_value1:
-        :param monitor_value2:
-        :return:
-        """
-        if monitor_value1 is None and monitor_value2 is None:
-            return True
-        if monitor_value1 is None:
-            return False
-        if monitor_value2 is None:
-            return True
-        better = False
-        if (self.larger_better and monitor_value1 > monitor_value2) or \
-            (not self.larger_better and monitor_value1 < monitor_value2):
-            better = True
-        return better
-
-    @property
-    def monitor_name(self):
-        """
-        Return the monitor's name; if monitor is a callable, return that function's name.
-
-        :return:
-        """
-        if callable(self.monitor):
-            try:
-                monitor_name = self.monitor.__qualname__
-            except:
-                monitor_name = self.monitor.__name__
-        elif self.monitor is None:
-            return None
-        else:
-            # this must be monitor rather than real_monitor, because real_monitor is reset to monitor on a re-run
-            monitor_name = str(self.monitor)
-        return monitor_name
diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py
index d2d97294..b13632d1 100644
--- a/fastNLP/core/callbacks/checkpoint_callback.py
+++ b/fastNLP/core/callbacks/checkpoint_callback.py
@@ -10,9 +10,9 @@ from copy import deepcopy

 import fastNLP
-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from fastNLP.core.log import logger
-from fastNLP.envs import FASTNLP_LAUNCH_TIME
+from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK
 from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir

@@ -217,7 +217,8 @@ class CheckpointCallback(HasMonitorCallback):
         :return:
         """
         folder = self.timestamp_path.joinpath(folder_name)
-        synchronize_mkdir(folder)
+        if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:  # only create on rank 0
+            synchronize_mkdir(folder)
         _fn = getattr(trainer, self.save_fn_name)
         _fn(
             folder=folder,
diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py
index c679ad7e..0923eb00 100644
--- a/fastNLP/core/callbacks/early_stop_callback.py
+++ b/fastNLP/core/callbacks/early_stop_callback.py
@@ -4,7 +4,7 @@

 from typing import Dict, Union, Callable

-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from fastNLP.core.utils.exceptions import EarlyStopException

diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py
new file mode 100644
index 00000000..54bd9bb4
--- /dev/null
+++ b/fastNLP/core/callbacks/has_monitor_callback.py
@@ -0,0 +1,189 @@
+__all__ = [
+    'HasMonitorCallback',
+    'ExecuteOnceBetterMonitor'
+]
+
+from typing import Dict, Union, Any
+from abc import ABC
+
+from fastNLP.core.utils import apply_to_collection
+from fastNLP.core.callbacks import Callback
+from fastNLP.core.callbacks.utils import _get_monitor_value
+from fastNLP.core.log import logger
+from fastNLP.core.utils.utils import _check_valid_parameters_number
+
+
+class CanItemDataType(ABC):
+    """
+    Duck-type check for objects whose values can be extracted via an ``item`` method.
+
+    """
+
+    @classmethod
+    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
+        if cls is CanItemDataType:
+            item = getattr(subclass, 'item', None)
+            return callable(item)
+        return NotImplemented
+
+
+
+class HasMonitorCallback(Callback):
+    def __init__(self, monitor, larger_better, must_have_monitor=False):
+        """
+        This callback is not meant to be used directly; it is the base class for other monitor-aware
+        callbacks. Subclasses inherit (1) validity checks on the monitor and (2) the fallback that sets
+        the monitor name from the Trainer's monitor when needed.
+
+        :param monitor: the metric to monitor. If no exactly matching name is found in the evaluation
+            results, the shortest-common-string heuristic picks the closest match as the monitor. If None,
+            the monitor set on the Trainer is tried. A function may also be passed; it receives the
+            evaluation results (a dict) and returns a float to use as the monitor value.
+        :param larger_better: whether a larger monitor value is better
+        :param must_have_monitor: whether this callback requires a monitor. If True and no monitor is
+            detected, an error is raised.
+        """
+        self.set_monitor(monitor, larger_better)
+        self.must_have_monitor = must_have_monitor
+
+    def set_monitor(self, monitor, larger_better):
+        if callable(monitor):  # check that it can accept a single argument
+            _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
+            self.monitor = monitor
+        else:
+            self.monitor = str(monitor) if monitor is not None else None
+        self.larger_better = bool(larger_better)
+        if larger_better:
+            self.monitor_value = float('-inf')
+        else:
+            self.monitor_value = float('inf')
+        self._real_monitor = self.monitor
+
+    def on_after_trainer_initialized(self, trainer, driver):
+        """
+        If no monitor is set on this callback, fall back to the monitor configured on the Trainer.
+        For callbacks that must have a monitor, this method also performs that check.
+
+        :param trainer:
+        :param driver:
+        :return:
+        """
+        if self.monitor is None and trainer.monitor is not None:
+            self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
+        if self.must_have_monitor and self.monitor is None:
+            raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
+                               f"You can set it in the initialization or through Trainer.")
+
+    def get_monitor_value(self, results:Dict)->Union[float, None]:
+        """
+        Get the monitor value. If the monitor is not found directly, fuzzy matching is tried and the
+        matched name is stored in the self._real_monitor attribute.
+
+        :param results:
+        :return: None means no suitable monitor was found this time
+        """
+        if len(results)==0:
+            return None
+        # make sure every tensor has been converted to a plain Python value
+        results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
+        use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
+                                                        real_monitor=self._real_monitor,
+                                                        res=results)
+        if monitor_value is None:
+            return monitor_value
+        # first run
+        if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
+            logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
+                           f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
+        # the monitor has changed since the last check
+        elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
+            logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
+                           f"The expected monitor is:`{self.monitor}`, last used monitor is:"
+                           f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
+                           f"customized monitor function when the evaluation results vary between validations.")
+
+        self._real_monitor = use_monitor
+        return monitor_value
+
+    def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
+        """
+        Check whether monitor_value is a better value.
+
+        :param monitor_value: the value to check; if None, False is returned
+        :param keep_if_better: if the given monitor_value is better, keep it as the new best.
+        :return:
+        """
+        if monitor_value is None:
+            return False
+        better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
+        if keep_if_better and better:
+            self.monitor_value = monitor_value
+        return better
+
+    def is_better_results(self, results, keep_if_better=True):
+        """
+        Check whether the given results are better than the previous best; if no matching monitor is
+        found in these results, False is returned.
+
+        :param results: the evaluation results passed to the on_validate_end() hook.
+        :param keep_if_better: when the value is better, whether to store it in self.monitor_value.
+        :return:
+        """
+        monitor_value = self.get_monitor_value(results)
+        if monitor_value is None:
+            return False
+        return self.is_better_monitor_value(monitor_value, keep_if_better=keep_if_better)
+
+    def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
+        """
+        Whether monitor_value1 is the better of the two given values.
+
+        :param monitor_value1:
+        :param monitor_value2:
+        :return:
+        """
+        if monitor_value1 is None and monitor_value2 is None:
+            return True
+        if monitor_value1 is None:
+            return False
+        if monitor_value2 is None:
+            return True
+        better = False
+        if (self.larger_better and monitor_value1 > monitor_value2) or \
+            (not self.larger_better and monitor_value1 < monitor_value2):
+            better = True
+        return better
+
+    @property
+    def monitor_name(self):
+        """
+        Return the monitor's name; if monitor is a callable, return that function's name.
+
+        :return:
+        """
+        if callable(self.monitor):
+            try:
+                monitor_name = self.monitor.__qualname__
+            except:
+                monitor_name = self.monitor.__name__
+        elif self.monitor is None:
+            return None
+        else:
+            # this must be monitor rather than real_monitor, because real_monitor is reset to monitor on a re-run
+            monitor_name = str(self.monitor)
+        return monitor_name
+
+
+class ExecuteOnceBetterMonitor(HasMonitorCallback):
+    def __init__(self, monitor, larger_better, execute_fn):
+        """
+        Call execute_fn whenever the monitored result improves.
+
+        :param monitor: the metric to monitor. If no exactly matching name is found in the evaluation
+            results, the shortest-common-string heuristic picks the closest match as the monitor. If None,
+            the monitor set on the Trainer is tried. A function may also be passed; it receives the
+            evaluation results (a dict) and returns a float to use as the monitor value.
+        :param larger_better: whether a larger monitor value is better
+        :param execute_fn: a callable that takes no arguments and returns nothing; invoked whenever the
+            monitor reaches a better result.
+        """
+        super().__init__(monitor, larger_better, must_have_monitor=True)
+        _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn')
+        self.execute_fn = execute_fn
+
+    def on_validate_end(self, trainer, results):
+        if self.is_better_results(results):
+            self.execute_fn()
\ No newline at end of file
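
For reference, a small sketch of how a subclass is expected to use this machinery (the PrintBestCallback below is illustrative; only HasMonitorCallback, is_better_results, monitor_name and monitor_value come from the file added above):

from fastNLP.core.callbacks.has_monitor_callback import HasMonitorCallback

class PrintBestCallback(HasMonitorCallback):
    def __init__(self, monitor=None, larger_better=True):
        # must_have_monitor=True makes the callback fail fast if neither it
        # nor the Trainer provides a monitor
        super().__init__(monitor, larger_better, must_have_monitor=True)

    def on_validate_end(self, trainer, results):
        # is_better_results() resolves the monitor (falling back to fuzzy
        # matching) and, with keep_if_better=True, stores the improved value
        if self.is_better_results(results, keep_if_better=True):
            print(f"new best {self.monitor_name}: {self.monitor_value}")
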
diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py
index 09f85d01..f240caa7 100644
--- a/fastNLP/core/callbacks/load_best_model_callback.py
+++ b/fastNLP/core/callbacks/load_best_model_callback.py
@@ -4,7 +4,7 @@ __all__ = [

 import os
 from typing import Optional, Callable, Union
-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from io import BytesIO
 import shutil

@@ -80,10 +80,7 @@ class LoadBestModelCallback(HasMonitorCallback):
         self.get_monitor_value(sanity_check_res)

     def on_validate_end(self, trainer, results):
-        monitor_value = self.get_monitor_value(results)
-        if monitor_value is None:
-            return
-        if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+        if self.is_better_results(results, keep_if_better=True):
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
                                    model_save_fn=self.model_save_fn)
diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py
index f351f204..bb638122 100644
--- a/fastNLP/core/callbacks/progress_callback.py
+++ b/fastNLP/core/callbacks/progress_callback.py
@@ -8,7 +8,7 @@ __all__ = [
     'RichCallback'
 ]

-from .callback import HasMonitorCallback
+from .has_monitor_callback import HasMonitorCallback
 from fastNLP.core.callbacks.utils import _get_monitor_value
 from fastNLP.core.utils import f_rich_progress
 from fastNLP.core.log import logger
diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
index 2c9c5162..f149855f 100644
--- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
@@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
     # world_size and rank
     if FASTNLP_BACKEND_LAUNCH in os.environ:
         if device is not None:
-            logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
+            logger.warning_once("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
                         "up your script. And we will directly get the local device via "
                         "`os.environ['LOCAL_RANK']`.")
         return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index 233d7040..f00d3f1f 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -25,7 +25,7 @@ __all__ = [

 from .utils import optimizer_state_to_device
 from fastNLP.core.drivers.driver import Driver
-from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env
+from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler
 from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
@@ -224,6 +224,11 @@ class TorchDriver(Driver):
             optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
             optimizers_state_dict[f"optimizer{i}"] = optimizer_state  # note: deepcopy is not used here; tests show it is unnecessary

+        # 4. save the fp16 (GradScaler) state
+        if not isinstance(self.grad_scaler, DummyGradScaler):
+            grad_scaler_state_dict = self.grad_scaler.state_dict()
+            states['grad_scaler_state_dict'] = grad_scaler_state_dict
+
         logger.debug("Save optimizer state dict")
         states["optimizers_state_dict"] = optimizers_state_dict
         torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
@@ -232,7 +237,7 @@ class TorchDriver(Driver):
         states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))

         # 1. load the optimizers' state;
-        optimizers_state_dict = states["optimizers_state_dict"]
+        optimizers_state_dict = states.pop("optimizers_state_dict")
         for i in range(len(self.optimizers)):
             optimizer: torch.optim.Optimizer = self.optimizers[i]
             optimizer.load_state_dict(optimizers_state_dict[f"optimizer{i}"])
@@ -244,26 +249,37 @@ class TorchDriver(Driver):
         res = torch.load(folder.joinpath(FASTNLP_MODEL_FILENAME), map_location='cpu')
         if only_state_dict:
             model.load_state_dict(res)
-            logger.debug("Load model state dict.")
+            logger.debug("Load model state dict...")
         else:
             model.load_state_dict(res.state_dict())
-            logger.debug("Load model.")
+            logger.debug("Load model...")

-        # 3. restore the sampler state;
+        # 3. load the fp16 (GradScaler) state
+        if 'grad_scaler_state_dict' in states:
+            grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
+            if not isinstance(self.grad_scaler, DummyGradScaler):
+                self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+                logger.debug("Load grad_scaler state dict...")
+        elif not isinstance(self.grad_scaler, DummyGradScaler):
+            logger.warning(f"Checkpoint {folder} was not trained with fp16=True, but you are resuming with "
+                           f"fp16=True; the training process may be unstable.")
+
+        # 4. restore the sampler state;
         dataloader_args = self.get_dataloader_args(dataloader)
         if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
             sampler = dataloader_args.batch_sampler
         elif isinstance(dataloader_args.sampler, ReproducibleSampler):
             sampler = dataloader_args.sampler
         elif self.is_distributed():
-            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
+            raise RuntimeError("Resuming from a checkpoint is only supported when you use our "
+                               "`ReproducibleBatchSampler` or `ReproducibleSampler`.")
         else:
             sampler = RandomBatchSampler(
                 batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
                 batch_size=dataloader_args.batch_size,
                 drop_last=dataloader_args.drop_last
             )
-        sampler.load_state_dict(states['sampler_states'])
+        sampler.load_state_dict(states.pop('sampler_states'))
         states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

         # 4. update trainer_state.batch_idx_in_epoch
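
Outside the driver, the GradScaler round trip added above boils down to the following standalone sketch (the checkpoint path and the surrounding dict layout are illustrative, not the driver's exact checkpoint format):

import torch

scaler = torch.cuda.amp.GradScaler()

# save step: stash the scaler state next to the other checkpoint entries
states = {"grad_scaler_state_dict": scaler.state_dict()}
torch.save(states, "checkpoint.pt")

# load step: restore the scaler only if the checkpoint actually carries one,
# mirroring the 'grad_scaler_state_dict' in states check in the driver
restored = torch.load("checkpoint.pt")
if "grad_scaler_state_dict" in restored:
    scaler.load_state_dict(restored["grad_scaler_state_dict"])
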
diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py
index 729ca960..e0d94cc8 100644
--- a/fastNLP/core/utils/utils.py
+++ b/fastNLP/core/utils/utils.py
@@ -203,7 +203,7 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
     :return:
     """
     if fn_name is not None:
-        assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}."
+        assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`."

     parameters = list(inspect.signature(fn).parameters.values())
     if inspect.ismethod(fn):
@@ -606,16 +606,38 @@ def seq_len_to_mask(seq_len, max_len=None):
     return mask


-def wait_to_success(fn, no=False):
+def wait_filepath(path, exist=True):
+    """
+    Wait until the existence state of path equals exist, then return.
+
+    :param path: the path to watch
+    :param exist: if True, return once the path exists; if False, return once the path no longer exists.
+    :return:
+    """
+    if isinstance(path, str):
+        path = Path(path)
+    assert isinstance(path, Path)
+    count = 0
     while True:
         sleep(0.01)
-        if (no and not fn()) or (not no and fn()):
+        if path.exists() == exist:
             break
+        count += 1
+        if count % 1000 == 0:
+            msg = 'created' if exist else 'deleted'
+            logger.warning(f"Waiting for path {path} to be {msg} for {count*0.01} seconds...")
+

-# This exists because of a race on distributed file systems: rank 0 issues the deletion and moves on,
-# but the actual deletion still has to reach the remote file system; rank 1 may then still see the file as existing.
 def synchronize_safe_rm(path: Optional[Union[str, Path]]):
+    """
+    On a distributed file system rank 0 may issue the deletion and move on while the remote file system
+    has not finished executing it: rank 0 already sees the file as deleted, but rank 1 still reads it as
+    existing. This function therefore only returns once every process has observed that path is gone.
+    Make sure path is exactly the same on all processes, otherwise this will deadlock.
+
+    :param path:
+    :return:
+    """
     if path is None:
         return
     if isinstance(path, str):
@@ -624,7 +646,7 @@ def synchronize_safe_rm(path: Optional[Union[str, Path]]):
         return
     if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
         _recursive_rm(path)
-    wait_to_success(path.exists, no=True)
+    wait_filepath(path, exist=False)


 def _recursive_rm(path: Path):
@@ -643,6 +665,8 @@ def _recursive_rm(path: Path):
 def synchronize_mkdir(path: Optional[Union[str, Path]]):
     """
     Note: this function is for creating a directory; do not use it if you need to create a file.
+    It only returns once every process has observed that path has been created. Make sure path is
+    exactly the same on all processes, otherwise this will deadlock.
+
     """
     if path is None:
         return
@@ -652,7 +676,7 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):

     if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
         path.mkdir(parents=True, exist_ok=True)

-    wait_to_success(path.exists)
+    wait_filepath(path, exist=True)


 def get_class_that_defined_method(method):
diff --git a/tests/envs/test_set_backend.py b/tests/envs/test_set_backend.py
index 2c8fbadf..03931bdc 100644
--- a/tests/envs/test_set_backend.py
+++ b/tests/envs/test_set_backend.py
@@ -1,6 +1,6 @@
 import os

-from fastNLP.envs.set_env import dump_fastnlp_backend
+from fastNLP.envs.set_backend import dump_fastnlp_backend
 from tests.helpers.utils import Capturing
 from fastNLP.core import synchronize_safe_rm
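
A sketch of the intended multi-process usage of the two synchronized helpers (the path below is illustrative, and FASTNLP_GLOBAL_RANK is normally set by the launcher rather than by hand):

import os
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir

os.environ.setdefault("FASTNLP_GLOBAL_RANK", "0")

# every process calls this with the same path: rank 0 creates the directory,
# all other ranks block in wait_filepath until the creation becomes visible
synchronize_mkdir("outputs/run_001")

# symmetrically, rank 0 deletes and every rank waits until the path is gone
synchronize_safe_rm("outputs/run_001")
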