fix conflict

This commit is contained in:
x54-729 2022-04-11 14:16:39 +00:00
commit 3a3c38a44e
46 changed files with 1814 additions and 853 deletions

131
README.md
View File

@ -6,4 +6,133 @@
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)
dev0.8.0正在开发中
fastNLP是一款轻量级的自然语言处理NLP工具包目标是快速实现NLP任务以及构建复杂模型。
fastNLP具有如下的特性
- 统一的Tabular式数据容器简化数据预处理过程
- 内置多种数据集的Loader和Pipe省去预处理代码;
- 各种方便的NLP工具例如Embedding加载包括ELMo和BERT、中间数据cache等;
- 部分[数据集与预训练模型](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0)的自动下载;
- 提供多种神经网络组件以及复现模型(涵盖中文分词、命名实体识别、句法分析、文本分类、文本匹配、指代消解、摘要等任务);
- Trainer提供多种内置Callback函数方便实验记录、异常捕获等。
## 安装指南
fastNLP 依赖以下包:
+ numpy>=1.14.2
+ torch>=1.0.0
+ tqdm>=4.28.1
+ nltk>=3.4.1
+ requests
+ spacy
+ prettytable>=0.7.2
其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 [PyTorch 官网](https://pytorch.org/) 。
在依赖包安装完成后,您可以在命令行执行如下指令完成安装
```shell
pip install fastNLP
python -m spacy download en
```
## fastNLP教程
中文[文档](https://fastnlp.readthedocs.io/)、[教程](https://fastnlp.readthedocs.io/zh/latest/user/tutorials.html)
### 快速入门
- [0. 快速入门](https://fastnlp.readthedocs.io/zh/latest/user/quickstart.html)
### 详细使用教程
- [1. 使用DataSet预处理文本](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_1_data_preprocess.html)
- [2. 使用Vocabulary转换文本与index](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_vocabulary.html)
- [3. 使用Embedding模块将文本转成向量](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html)
- [4. 使用Loader和Pipe加载并处理数据集](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html)
- [5. 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_5_loss_optimizer.html)
- [6. 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_6_datasetiter.html)
- [7. 使用Metric快速评测你的模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_7_metrics.html)
- [8. 使用Modules和Models快速搭建自定义模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_8_modules_models.html)
- [9. 快速实现序列标注模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_9_seq_labeling.html)
- [10. 使用Callback自定义你的训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_10_callback.html)
### 扩展教程
- [Extend-1. BertEmbedding的各种用法](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_1_bert_embedding.html)
- [Extend-2. 分布式训练简介](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_dist.html)
- [Extend-3. 使用fitlog 辅助 fastNLP 进行科研](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_3_fitlog.html)
## 内置组件
大部分用于的 NLP 任务神经网络都可以看做由词嵌入embeddings和两种模块编码器encoder、解码器decoder组成。
以文本分类任务为例下图展示了一个BiLSTM+Attention实现文本分类器的模型流程图
![](./docs/source/figures/text_classification.png)
fastNLP 在 embeddings 模块中内置了几种不同的embedding静态embeddingGloVe、word2vec、上下文相关embedding
ELMo、BERT、字符embedding基于CNN或者LSTM的CharEmbedding
与此同时fastNLP 在 modules 模块中内置了两种模块的诸多组件,可以帮助用户快速搭建自己所需的网络。 两种模块的功能和常见组件如下:
<table>
<tr>
<td><b> 类型 </b></td>
<td><b> 功能 </b></td>
<td><b> 例子 </b></td>
</tr>
<tr>
<td> encoder </td>
<td> 将输入编码为具有具有表示能力的向量 </td>
<td> Embedding, RNN, CNN, Transformer, ...
</tr>
<tr>
<td> decoder </td>
<td> 将具有某种表示意义的向量解码为需要的输出形式 </td>
<td> MLP, CRF, ... </td>
</tr>
</table>
## 项目结构
<div align=center><img width="450" height="350" src="./docs/source/figures/workflow.png"/></div>
fastNLP的大致工作流程如上图所示而项目结构如下
<table>
<tr>
<td><b> fastNLP </b></td>
<td> 开源的自然语言处理库 </td>
</tr>
<tr>
<td><b> fastNLP.core </b></td>
<td> 实现了核心功能,包括数据处理组件、训练器、测试器等 </td>
</tr>
<tr>
<td><b> fastNLP.models </b></td>
<td> 实现了一些完整的神经网络模型 </td>
</tr>
<tr>
<td><b> fastNLP.modules </b></td>
<td> 实现了用于搭建神经网络模型的诸多组件 </td>
</tr>
<tr>
<td><b> fastNLP.embeddings </b></td>
<td> 实现了将序列index转为向量序列的功能包括读取预训练embedding等 </td>
</tr>
<tr>
<td><b> fastNLP.io </b></td>
<td> 实现了读写功能,包括数据读入与预处理,模型读写,数据与模型自动下载等 </td>
</tr>
</table>
<hr>
*In memory of @FengZiYjun. May his soul rest in peace. We will miss you very very much!*

View File

@ -4,7 +4,8 @@ __all__ = [
'EventsList',
'Filter',
'CallbackManager',
'CheckpointCallback',
'ModelCheckpointCallback',
'TrainerCheckpointCallback',
'choose_progress_callback',
'ProgressCallback',
'RichCallback',
@ -16,7 +17,7 @@ __all__ = [
from .callback import Callback
from .callback_events import EventsList, Events, Filter
from .callback_manager import CallbackManager
from .checkpoint_callback import CheckpointCallback
from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback
from .lr_scheduler_callback import LRSchedCallback
from .load_best_model_callback import LoadBestModelCallback

View File

@ -8,7 +8,7 @@ __all__ = [
from .callback_events import Events
from .callback import Callback
from .checkpoint_callback import CheckpointCallback
from .checkpoint_callback import TrainerCheckpointCallback
from .progress_callback import ProgressCallback, choose_progress_callback
from fastNLP.core.log import logger
@ -98,7 +98,7 @@ class CallbackManager:
:return:
"""
for each_callback in self.class_callbacks:
if isinstance(each_callback, CheckpointCallback) and each_callback.is_trainer_checkpoint:
if isinstance(each_callback, TrainerCheckpointCallback):
self._has_trainer_checkpoint = True
self.dissect_one_callback(each_callback)
@ -210,7 +210,7 @@ class CallbackManager:
each_callback.on_load_checkpoint(trainer, None)
@property
def has_trainer_chechpoint(self) -> bool:
def has_trainer_checkpoint(self) -> bool:
return self._has_trainer_checkpoint
@_transfer

View File

@ -1,12 +1,13 @@
import os
from typing import Union, Optional, Callable, Dict, Sequence
from pathlib import Path
from functools import partial
from time import sleep
__all__ = [
'CheckpointCallback'
'ModelCheckpointCallback',
'TrainerCheckpointCallback'
]
import os
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping
from pathlib import Path
from abc import ABC
import sys
import fastNLP
from .callback import Callback, Filter
@ -14,35 +15,37 @@ from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
from fastNLP.core.utils import apply_to_collection
class CanItemDataType(ABC):
"""
检测可以进行传输的对象
"""
@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanItemDataType:
item = getattr(subclass, 'item', None)
return callable(item)
return NotImplemented
class CheckpointCallback(Callback):
"""
1. 因为只有 'Trainer' 才有 callback因此评测 metric 实际上就是 validate 时干的事情
2. 默认 'save_last' True model_checkpoint 的默认逻辑是在每一个 epoch 下保存最后的一个模型模型名字为 last.pth.tar
3. 理论上一个 model_checkpoint 的实例只会负责一个 monitor 的监视如果用户在训练过程中指定了多个 monitor 的监视例如 "acc1",
"acc2", ... 那么我们会为用户创建多个 model_checkpoint 的实例
4. 理论上在实际保存的过程中topk 模式和 固定频率保存的模式是完全独立的我们确实应当采取一些措施至少保证两者的名字不一样
"""
def __init__(
self,
monitor,
is_trainer_checkpoint: Optional[bool] = False,
save_folder: Optional[Union[str, Path]] = None,
save_every_n_epochs: Optional[int] = None,
save_every_n_global_batches: Optional[int] = None,
save_every_n_batches: Optional[int] = None,
save_last: bool = True,
save_topk: Optional[int] = None,
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None,
larger_better: bool = True,
only_state_dict: bool = True,
model_save_fn: Optional[Callable] = None,
**kwargs,
):
if monitor is None and save_topk is not None:
@ -51,9 +54,6 @@ class CheckpointCallback(Callback):
if monitor is not None and not isinstance(monitor, str):
raise ValueError("Parameter `monitor` should be of 'str' type.")
if not isinstance(is_trainer_checkpoint, bool):
raise TypeError("Parameter 'is_trainer_checkpoint' can only be `bool` type.")
if save_folder is None:
logger.warning(
"Parameter `path` is None, and we will use the current work directory to find and load your model.")
@ -67,15 +67,15 @@ class CheckpointCallback(Callback):
if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1:
raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.")
# 突然发现有一个骚操作在于 'Filter' 内部记载的状态值例如 'num_called' 是这个类全局的,而每次调用 __call__ 中输入的
# 函数却是及时传入的,也就是说,我们可以保证 'Filter' 的正常控制频率的逻辑,然后每一次运行的函数都不一样;
self._filter_every_n_epochs = Filter(every=save_every_n_epochs)
else:
save_every_n_epochs = sys.maxsize # 使得没有数字可以整除
if save_every_n_global_batches is not None:
if not isinstance(save_every_n_global_batches, int) or save_every_n_global_batches < 1:
if save_every_n_batches is not None:
if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1:
raise ValueError(
"parameter save_every_n_global_batches should be an int and greater than or equal to 1.")
self._filter_every_n_global_batches = Filter(every=save_every_n_global_batches)
"parameter save_every_n_batches should be an int and greater than or equal to 1.")
else:
save_every_n_batches = sys.maxsize # 使得没有数字可以整除
if save_topk is not None:
if not isinstance(save_topk, int) or save_topk < 1:
@ -89,12 +89,12 @@ class CheckpointCallback(Callback):
if not issubclass(exception, BaseException):
raise TypeError("Each exception in parameter `save_on_exception` can only be "
"`BaseException` type.")
else:
save_on_exception = []
self.monitor = monitor
self.is_trainer_checkpoint = is_trainer_checkpoint
self.save_folder = Path(save_folder)
self.save_every_n_epochs = save_every_n_epochs
self.save_every_n_global_batches = save_every_n_global_batches
self.save_every_n_batches = save_every_n_batches
self.save_last = save_last
self.save_topk = save_topk
self.larger_better = larger_better
@ -107,7 +107,7 @@ class CheckpointCallback(Callback):
self._topk_model = {}
self._topn = 0 # 表示目前已经保存了几个最好的模型;
# 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用模糊匹配找到的第一个
# 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的
# key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次
# 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为;
# 因此我们通过该变量来表示我们通过模糊匹配拿到的 key
@ -115,76 +115,83 @@ class CheckpointCallback(Callback):
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候,
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的;
self.log_filepath = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
# 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行;
synchronize_mkdir(self.log_filepath)
synchronize_mkdir(self.timestamp_path)
def on_validate_end(self, trainer, validate_res):
self._save_topk(trainer, validate_res)
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"):
self._save_every_n_epochs(trainer)
self._save_last(trainer)
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0:
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}'
self.save(trainer, folder_name=folder_name)
if self.save_last:
folder_name = f'{self.folder_prefix}-last'
self.save(trainer, folder_name=folder_name)
def on_train_batch_end(self, trainer):
self._save_every_n_global_batches(trainer)
if trainer.global_forward_batches % self.save_every_n_batches == 0:
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}'
self.save(trainer, folder_name=folder_name)
def on_exception(self, trainer, exception: BaseException):
if self.save_on_exception is not None and exception.__class__ in self.save_on_exception:
folder = self._get_checkpoint_real_save_folder(trainer=trainer, topk=False, metric=None)
folder = folder + f"_{exception.__class__.__name__}"
self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder=folder)
if exception.__class__ in self.save_on_exception:
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \
f'exception_{exception.__class__.__name__}'
self.save(trainer=trainer, folder_name=folder_name)
def on_sanity_check_end(self, trainer, sanity_check_res):
# 主要核对一下 monitor 是否存在。
self._get_validate_metric(sanity_check_res)
def on_save_checkpoint(self, trainer) -> Dict:
"""
我们需要保存 CheckpointCallback 内部的几个 filter 的状态
保存 timestamp_path 使得之后可以继续训练并保存到该文件夹
topk_model的状态
_real_monitor的值
"""
states = {}
if self.save_every_n_epochs is not None:
states["_filter_every_n_epochs"] = self._filter_every_n_epochs.state_dict()
if self.save_every_n_global_batches is not None:
states["_filter_every_n_global_batches"] = self._filter_every_n_global_batches.state_dict()
states["real_monitor"] = self._real_monitor
states['timestamp_path'] = str(self.timestamp_path.absolute())
states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType,
function=lambda x:x.item())
states['save_topk'] = 0 if self.save_topk is None else self.save_topk
states['_real_monitor'] = self._real_monitor
return states
def on_load_checkpoint(self, trainer, states: Optional[Dict]):
if self.save_every_n_epochs is not None:
self._filter_every_n_epochs.load_state_dict(states["_filter_every_n_epochs"])
if self.save_every_n_global_batches is not None:
self._filter_every_n_global_batches.load_state_dict(states["_filter_every_n_global_batches"])
timestamp_path = states['timestamp_path']
if not os.path.exists(timestamp_path):
logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to "
f" {self.timestamp_path.absolute()}.")
else:
logger.info(f"Resume to save in path: {timestamp_path}.")
self.timestamp_path = Path(timestamp_path)
_topk_model = states['_topk_model']
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk'])
if save_topk is not None and self.save_topk is not None:
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \
f"as {save_topk}."
self._topk_model.update(self._topk_model)
self._real_monitor = states["real_monitor"]
def _save_every_n_epochs(self, trainer: "fastNLP.Trainer"):
if self.save_every_n_epochs is not None:
if self.is_trainer_checkpoint:
_fn_every_n_epochs = trainer.save
else:
_fn_every_n_epochs = trainer.save_model
_fn_every_n_epochs = partial(self._save_fn, trainer, False, None, _fn_every_n_epochs, None)
_fn_every_n_epochs = self._filter_every_n_epochs(_fn_every_n_epochs)
_fn_every_n_epochs()
def _save_every_n_global_batches(self, trainer: "fastNLP.Trainer"):
if self.save_every_n_global_batches is not None:
if self.is_trainer_checkpoint:
_fn_every_n_global_batches = trainer.save
else:
_fn_every_n_global_batches = trainer.save_model
_fn_every_n_global_batches = partial(self._save_fn, trainer, False, None, _fn_every_n_global_batches, None)
_fn_every_n_global_batches = self._filter_every_n_global_batches(_fn_every_n_global_batches)
_fn_every_n_global_batches()
def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict):
"""
根据validate_res决定保存哪些model的函数会自动移除掉不满足topk的文件夹
:param trainer:
:param validate_res:
:return:
"""
if self.save_topk is not None:
_metric_value = self._get_validate_metric(validate_res)
_saved_name = self._get_checkpoint_real_save_folder(trainer=trainer, topk=True, metric=_metric_value)
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
f"-{self._real_monitor}_{_metric_value}"
_should_save = False
if self._topn < self.save_topk:
self._topk_model[_saved_name] = _metric_value
self._topk_model[folder_name] = _metric_value
self._topn += 1
_should_save = True
else:
@ -192,39 +199,27 @@ class CheckpointCallback(Callback):
key=lambda x: self._topk_model[x])
if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \
(self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]):
self._topk_model[_saved_name] = _metric_value
self._topk_model[folder_name] = _metric_value
_should_save = True
self._topk_model.pop(_least_valuable_model)
synchronize_safe_rm(self.log_filepath.joinpath(_least_valuable_model))
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model))
assert len(self._topk_model) == self.save_topk == self._topn
if _should_save:
self._save_fn(trainer=trainer, topk=True, metric=_metric_value, substitute_folder=_saved_name)
self.save(trainer, folder_name=folder_name)
def _save_last(self, trainer: "fastNLP.Trainer"):
if self.save_last:
self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder="last")
def _save_fn(self, trainer, topk: bool = False, metric: Optional[Union[int, float]] = None,
substitute_fn: Optional[Callable] = None, substitute_folder: Optional[str] = None):
# 首先根据当前的 epoch 和 batch 在 parent_path/FASTNLP_LAUNCH_TIME 下创建子文件夹 epoch-batch-monitor 或者
# epoch-batch-monitor-monitor_value
if substitute_folder is None:
folder = self.log_filepath.joinpath(self._get_checkpoint_real_save_folder(trainer, topk, metric))
else:
folder = self.log_filepath.joinpath(substitute_folder)
def save(self, trainer, folder_name):
"""
执行保存的函数将数据保存在 save_folder/timestamp/folder_name
:param trainer:
:param folder_name:
:return:
"""
folder = self.timestamp_path.joinpath(folder_name)
synchronize_mkdir(folder)
# 然后再调用 trainer 的 save_model用于保存模型或者 save用于断点重训函数
if substitute_fn is not None:
_fn = substitute_fn
else:
if self.is_trainer_checkpoint:
_fn = trainer.save
else:
_fn = trainer.save_model
_fn = getattr(trainer, self.save_fn_name)
_fn(
folder=folder,
only_state_dict=self.only_state_dict,
@ -243,18 +238,48 @@ class CheckpointCallback(Callback):
self._real_monitor = use_monitor
return value
def _get_checkpoint_real_save_folder(self, trainer: "fastNLP.Trainer", topk: bool = False,
metric: Optional[Union[int, float]] = None) -> str:
@property
def folder_prefix(self):
raise NotImplementedError("The `folder_prefix` is not specified")
@property
def save_fn_name(self):
raise NotImplementedError("The `save_fn_name` is not specified.")
class ModelCheckpointCallback(CheckpointCallback):
"""
获取当前保存模型的真正地名字
metric 参数仅当 mode 'topk' 时起作用
保存模型 checkpoint callback 其保存的文件目录以及文件名命名规则如下
- save_folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
- model-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型
- model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型
- model-last/ # 最后一个 epoch 的保存
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名
model_save_fn None 则以上每个 folder 将生成 fastnlp_model.pkl.tar 文件
model_save_fn 不为 None fastNLP folder 绝对路径传递给该函数fastNLP 不在该 folder 下创建任何文件
:param monitor: 监控的 metric 的名称如果在 evaluation 结果中没有找到完全一致的名称将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor
:param save_folder: 保存的文件夹fastNLP 将在该文件下以时间戳创建子文件夹并在里面保存因此不同次运行可以将被保存到不同的
时间戳文件夹中如果为 None 默认使用当前文件夹
:param save_every_n_epochs: 多少个 epoch 保存一次
:param save_every_n_batches: 多少个 batch 保存一次
:param save_last: 如果为 True 将在每次 epoch 运行结束都保存一次会覆盖之前的保存
:param save_topk: 保存 monitor 结果 topK
:param save_on_exception: 在出异常信息时是否保存传入需要捕获的异常的类
:param larger_better: monitor 的值是否时越大越好
:param only_state_dict: 保存模型时是否只保存 state_dict model_save_fn 不为 None 该参数无效
:param model_save_fn: 个性化的保存函数当触发保存操作时就调用这个函数这个函数应当接受一个文件夹作为参数不返回任何东西
如果传入了 model_save_fn 函数fastNLP 将不再进行模型相关的保存在多卡场景下我们只在 rank 0 上会运行该函数
:param kwargs:
"""
cur_epoch_idx = trainer.cur_epoch_idx
global_forward_batches = trainer.global_forward_batches
_other = ""
if topk:
_other = f"_{metric}"
return f"epoch_{cur_epoch_idx}-global_batch_{global_forward_batches}-{self._real_monitor}{_other}"
@property
def save_fn_name(self):
return 'save_model'
@property
def callback_name(self):
@ -262,6 +287,55 @@ class CheckpointCallback(Callback):
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态
:return:
"""
return f"monitor-{self.monitor}#trainer_checkpoint-{self.is_trainer_checkpoint}#only_state_dict-{self.only_state_dict}"
return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
@property
def folder_prefix(self):
return 'model'
class TrainerCheckpointCallback(CheckpointCallback):
"""
保存 Trainer checkpoint callback 其保存的文件目录以及文件名命名规则如下
- save_folder/
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
- trainer-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型
- trainer-last/ # 最后一个 epoch 的保存
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名
model_save_fn None 则以上每个 folder 将生成两个文件fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar
model_save_fn 不为 None fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件
:param monitor: 监控的 metric 的名称如果在 evaluation 结果中没有找到完全一致的名称将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor
:param save_folder: 保存的文件夹fastNLP 将在该文件下以时间戳创建子文件夹并在里面保存因此不同次运行可以将被保存到不同的
时间戳文件夹中如果为 None 默认使用当前文件夹
:param save_every_n_epochs: 多少个 epoch 保存一次
:param save_every_n_batches: 多少个 batch 保存一次
:param save_last: 如果为 True 将在每次 epoch 运行结束都保存一次会覆盖之前的保存
:param save_topk: 保存 monitor 结果 topK
:param save_on_exception: 在出异常信息时是否保存
:param larger_better: monitor 的值是否时越大越好
:param only_state_dict: 保存模型时是否只保存 state_dict model_save_fn 不为 None 该参数无意义
:param model_save_fn: 个性化的保存函数当触发保存操作时就调用这个函数这个函数应当接受一个文件夹作为参数不返回任何东西
如果传入了 model_save_fn 函数fastNLP 将不再进行模型相关的保存在多卡场景下我们只在 rank 0 上会运行该函数
:param kwargs:
"""
@property
def save_fn_name(self):
return 'save'
@property
def callback_name(self):
"""
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态
:return:
"""
return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
@property
def folder_prefix(self):
return 'trainer'

View File

@ -31,7 +31,7 @@ class LoadBestModelCallback(Callback):
请在函数内完成对模型的保存
:param model_load_fn: 加载 model 的函数 model_save_fn 必须同时不为空本函数的输入为一个已经创建好的文件夹没有输出
请在函数内完成对模型的加载
:param delete_after_train: 加载了最佳模型之后是否删掉模型
:param delete_after_train: 训练结束后是否删掉模型
"""
if model_load_fn is not None:
assert callable(model_load_fn), "`model_load_fn` must be a callable object."

View File

@ -133,17 +133,18 @@ class Evaluator:
self.driver.barrier()
def run(self, num_eval_batch_per_dl: int = -1) -> Dict:
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
"""
返回一个字典类型的数据其中key为metric的名字value为对应metric的结果
如果存在多个metric一个dataloader的情况key的命名规则是
metric_indicator_name#metric_name
如果存在多个数据集一个metric的情况key的命名规则是
metric_indicator_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。
metric_indicator_name#metric_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。
如果存在多个metric多个dataloader的情况key的命名规则是
metric_indicator_name#metric_name#dataloader_name
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据-1 为测试所有数据
其中 metric_indicator_name 可能不存在
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据-1 为测试所有数据
:return:
"""
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
@ -157,7 +158,6 @@ class Evaluator:
assert self.driver.has_test_dataloaders()
metric_results = {}
self.reset()
evaluate_context = self.driver.get_evaluate_context()
self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train')

View File

@ -23,7 +23,7 @@ from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
from fastNLP.envs import rank_zero_call
from fastNLP.core.samplers import ReproducibleIterator, ReproducibleBatchSampler
from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME
@ -251,7 +251,7 @@ class Trainer(TrainerEventTrigger):
self.driver.set_deterministic_dataloader(self.dataloader)
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
reproducible=self.callback_manager.has_trainer_chechpoint)
reproducible=self.callback_manager.has_trainer_checkpoint)
self.set_grad_to_none = kwargs.get("set_grad_to_none", True)
self.on_after_trainer_initialized(self.driver)
@ -291,6 +291,7 @@ class Trainer(TrainerEventTrigger):
raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.")
if self.evaluator is not None and num_eval_sanity_batch > 0:
logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.")
self.on_sanity_check_begin()
sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch)
self.on_sanity_check_end(sanity_check_res)
@ -509,7 +510,7 @@ class Trainer(TrainerEventTrigger):
:param folder: 保存模型的地址
:param only_state_dict: 是否只保存模型的 `state_dict`
:param save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数
:param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数
:param kwargs: 一些 driver 的保存模型的函数的参数另有其它
"""
@ -534,7 +535,16 @@ class Trainer(TrainerEventTrigger):
def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False,
model_load_fn: Optional[Callable] = None, **kwargs):
"""
加载模型
:param folder: 读取 model 的文件夹默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件 model_load_fn 不为空时
直接将该 folder 传递到 model_load_fn
:param only_state_dict: 要读取的文件中是否仅包含模型权重 model_load_fn 不为 None 该参数无意义
:param model_load_fn: callable 的函数接受一个 folder 作为参数不返回任何内容
:param kwargs:
:return:
"""
self.on_load_model()
self.driver.barrier()
if not isinstance(folder, (io.BytesIO, BinaryIO)):
@ -555,7 +565,13 @@ class Trainer(TrainerEventTrigger):
def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs):
r"""
用于断点重训的保存函数;
用于断点重训 Trainer 的保存函数;
:param folder:
:param only_state_dict:
:param model_save_fn:
:param kwargs:
:return:
"""
self.driver.barrier()
@ -594,7 +610,7 @@ class Trainer(TrainerEventTrigger):
r"""
用于断点重训的加载函数
注意在 fastNLP 中断点重训的保存和加载逻辑是分开的因此可能存在一种情况用户只希望加载一个断点重训的状态而在之后不再进行断点重训的
保存在这种情况下dataloader sampler 就不一定会被替换成我们的 ReproducibleIterator
保存在这种情况下dataloader sampler 就不一定会被替换成我们的 ReproducibleSampler
注意我们目前不支持单卡到多卡的断点重训

View File

@ -24,6 +24,7 @@ class _FDataSet:
对Dataset的封装主要是修改dataset的__getitem__函数增加返回下标idx值得注意的是dataset需要实现__getattribute__函数才能在_FDataset
中调用dataset的方法
"""
def __init__(self, dataset) -> None:
self.dataset = dataset
@ -45,6 +46,7 @@ class TorchDataLoader(DataLoader):
提供给使用pytorch框架的DataLoader函数若是配套使用FastNLP的dataset则可以自动使用AutoCollate函数对数据进行自动padding操作用户也可以通过
提供的方法调节设置collate_fn的若干参数
"""
def __init__(self, dataset, batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
@ -184,7 +186,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
non_train_batch_size: int = 16, as_numpy: bool = False,
input_fields: Union[List, str] = None)\
input_fields: Union[List, str, None] = None) \
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
"""
传入dataset或者data_bundle后将其处理返回相对应的FdataLoader实例化对象
@ -221,6 +223,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
if input_fields:
dl.set_input(*input_fields)
return dl
@ -233,16 +236,20 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)
else:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
shuffle=shuffle, sampler=non_train_sampler,
batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)
if input_fields:
dl_bundle[name].set_input(*input_fields)
return dl_bundle
@ -269,6 +276,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
)
if input_fields:
for dl in dl_bundle:
dl.set_input(*input_fields)
return dl_bundle
@ -282,17 +290,21 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)
else:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
shuffle=shuffle, sampler=non_train_sampler,
batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy)
if input_fields:
dl_bundle[name].set_input(*input_fields)
return dl_bundle

View File

@ -8,9 +8,8 @@ __all__ = [
import _pickle as pickle
from copy import deepcopy
from typing import Optional, List, Callable, Union, Dict, Any
from typing import Optional, List, Callable, Union, Dict, Any, Mapping
from functools import partial
import warnings
import numpy as np
from threading import Thread
@ -197,6 +196,20 @@ class DataSet:
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
def __setitem__(self, key, value):
assert isinstance(key, int) and key<len(self)
assert isinstance(value, Instance) or isinstance(value, Mapping)
ins_keys = set(value.keys())
ds_keys = set(self.get_field_names())
if len(ins_keys - ds_keys) != 0:
raise KeyError(f"The following keys are not found in the Dataset:{list(ins_keys - ds_keys)}.")
if len(ds_keys - ins_keys) != 0:
raise KeyError(f"The following keys are not found in the Instance:{list(ds_keys - ins_keys)}.")
for field_name, field in self.field_arrays.items():
field[key] = value[field_name]
def __getattribute__(self, item):
return object.__getattribute__(self, item)
@ -813,6 +826,3 @@ class DataSet:
self.collate_fns.set_input(*field_names)
class IterableDataset:
pass

View File

@ -46,9 +46,6 @@ class FieldArray:
def __setitem__(self, idx: int, val: Any):
assert isinstance(idx, int)
if idx == -1:
idx = len(self) - 1
assert 0 <= idx < len(self), f"0<= idx <{len(self)}, but idx is {idx}"
self.content[idx] = val
def get(self, indices: Union[int, List[int]]):
@ -79,7 +76,7 @@ class FieldArray:
def split(self, sep: str = None, inplace: bool = True):
r"""
依次对自身的元素使用.split()方法应该只有当本field的元素为str时该方法才有用将返回值
依次对自身的元素使用.split()方法应该只有当本field的元素为str时该方法才有用
:param sep: 分割符如果为None则直接调用str.split()
:param inplace: 如果为True则将新生成值替换本field否则返回list

View File

@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from io import BytesIO
import json
__all__ = [
'Driver'
@ -48,10 +49,13 @@ class Driver(ABC):
不同 gpu 上出现重复 'unrepeatdist' 表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的
数据允许不同 gpu batch 的数量不一致其中 trainer kwargs 的参数 `use_dist_sampler` True 该值为 "dist"
否则为 None evaluator 中的 kwargs 的参数 `use_dist_sampler` True 该值为 "unrepeatdist"否则为 None
注意当 dist ReproducibleIterator, RandomBatchSampler 是断点重训加载时 driver.load 函数在调用
dist str 或者 None trainer 在初始化时调用该函数
:param reproducible: 如果为 False 不要做任何考虑如果为 True 需要保证返回的 dataloader 可以保存当前的迭代状态使得
可以可以加载
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) 此外
如果传入的 dataloader 中是 ReproducibleIterator 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的
如果传入的 dataloader 中是 ReproducibleSampler 或者 RandomBatchSampler 需要重新初始化一个放入返回的
dataloader 如果 dist 为空 reproducible False可直接返回原对象
"""
if dist is None and reproducible is False:
@ -68,9 +72,12 @@ class Driver(ABC):
def set_sampler_epoch(self, dataloader, cur_epoch_idx):
r"""
对于分布式的 sampler例如 torch DistributedSampler其需要在每一个 epoch 前设置随机数种子来保证每一个进程上的 shuffle 是一样的
dataloader 中可能真正发挥作用的是 batch_sampler 也可能是 sampler
:param dataloader: 需要设置 epoch dataloader
:param cur_epoch_idx: 当前是第几个 epoch
"""
@abstractmethod
def train_step(self, batch):
"""
@ -444,13 +451,14 @@ class Driver(ABC):
exc_type, exc_value, exc_traceback_obj = sys.exc_info()
_write_exc_info = {
'exc_type': exc_type,
'exc_value': exc_value,
'time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')),
'global_rank': getattr(self, "global_rank", None),
'rank': self.get_local_rank(),
'exc_type': str(exc_type.__name__),
'exc_value': str(exc_value),
'exc_time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')),
'exc_global_rank': getattr(self, "global_rank", None),
'exc_local_rank': self.get_local_rank(),
}
sys.stderr.write(str(_write_exc_info)+"\n")
sys.stderr.write("\nException info:\n")
sys.stderr.write(json.dumps(_write_exc_info, indent=2)+"\n")
sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n")
for pid in self._pids:

View File

@ -3,7 +3,7 @@ from typing import Optional, Union
from .jittor_driver import JittorDriver
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleIterator
from fastNLP.core.samplers import ReproducibleSampler
if _NEED_IMPORT_JITTOR:
import jittor
@ -70,7 +70,7 @@ class JittorMPIDriver(JittorDriver):
def test_step(self, batch):
return self._test_step(batch)
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
reproducible: bool = False, sampler_or_batch_sampler=None):
pass

View File

@ -3,7 +3,7 @@ from typing import Dict, Union
from .jittor_driver import JittorDriver
from fastNLP.core.utils import auto_param_call
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler
if _NEED_IMPORT_JITTOR:
import jittor
@ -99,25 +99,25 @@ class JittorSingleDriver(JittorDriver):
def is_distributed(self):
return False
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler],
reproducible: bool = False, sampler_or_batch_sampler=None):
# reproducible 的相关功能暂时没有实现
if isinstance(dist, ReproducibleBatchSampler):
if isinstance(dist, RandomBatchSampler):
raise NotImplementedError
dataloader.batch_sampler = dist_sample
if isinstance(dist, ReproducibleIterator):
if isinstance(dist, ReproducibleSampler):
raise NotImplementedError
dataloader.batch_sampler.sampler = dist
if reproducible:
raise NotImplementedError
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
return dataloader
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
elif isinstance(dataloader.batch_sampler, RandomBatchSampler):
return dataloader
else:
# TODO
batch_sampler = ReproducibleBatchSampler(
batch_sampler = RandomBatchSampler(
batch_sampler=dataloader.batch_sampler,
batch_size=dataloader.batch_sampler.batch_size,
drop_last=dataloader.drop_last

View File

@ -21,6 +21,8 @@ from fastNLP.core.utils import (
is_in_paddle_dist,
)
from fastNLP.core.samplers import (
RandomBatchSampler,
ReproducibleSampler,
ReproducibleIterator,
RandomSampler,
UnrepeatedDistributedSampler,
@ -318,7 +320,7 @@ class PaddleFleetDriver(PaddleDriver):
def test_step(self, batch):
return self._test_step(batch)
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
reproducible: bool = False, sampler_or_batch_sampler=None):
# 暂时不支持iterableDataset
assert dataloader.dataset_kind != _DatasetKind.ITER, \

View File

@ -11,7 +11,11 @@ from fastNLP.core.utils import (
get_paddle_device_id,
paddle_move_data_to_device,
)
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator, re_instantiate_sampler
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
ReproducibleIterator,
re_instantiate_sampler,
)
from fastNLP.core.log import logger
if _NEED_IMPORT_PADDLE:

View File

@ -23,15 +23,16 @@ from fastNLP.core.drivers.torch_driver.utils import (
ForwardState,
_MODE_PARAMETER,
reset_seed,
replace_sampler
replace_sampler,
replace_batch_sampler
)
from fastNLP.core.drivers.utils import distributed_open_proc
from fastNLP.core.utils import auto_param_call, check_user_specific_params
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
from fastNLP.core.log import logger
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
from fastNLP.core.samplers import re_instantiate_sampler
class TorchDDPDriver(TorchDriver):
@ -445,25 +446,52 @@ class TorchDDPDriver(TorchDriver):
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
return self._test_step(batch)
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
reproducible: bool = False, sampler_or_batch_sampler=None):
if isinstance(dist, ReproducibleIterator):
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None,
reproducible: bool = False):
# 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用;
# 注意这里不需要调用 dist_sampler.set_distributed因为如果用户使用的是 TorchDDPDriver那么其在 Trainer 初始化的时候就已经调用了该函数;
dist = re_instantiate_sampler(dist)
if isinstance(dist, RandomBatchSampler):
dist.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleSampler):
dist.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_sampler(dataloader, dist)
# 如果 dist 为 str 或者 None说明是在 trainer 初试化时调用;
# trainer, evaluator
if dist is None:
if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
"control.")
else:
if isinstance(dist, RandomBatchSampler):
dist = re_instantiate_sampler(dist)
return replace_batch_sampler(dataloader, dist)
if isinstance(dist, ReproducibleSampler):
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)
return dataloader
# trainer
elif dist == "dist":
args = self.get_dataloader_args(dataloader)
# 如果用户的 trainer.use_dist_sampler 为 True那么此时其是否进行断点重训不影响这里的行为
if isinstance(args.sampler, ReproducibleIterator):
if isinstance(args.batch_sampler, RandomBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
batch_sampler.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_batch_sampler(dataloader, batch_sampler)
elif isinstance(args.sampler, ReproducibleSampler):
sampler = re_instantiate_sampler(args.sampler)
sampler.set_distributed(
num_replicas=self.world_size,
@ -477,21 +505,23 @@ class TorchDDPDriver(TorchDriver):
shuffle=args.shuffle,
seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
)
# todo 这个你写个todo吧有两个角度第一个是dataloader即使检测到sampler是我们reproducible也不能直接set_distributeds; 第二个如果是单卡的也需要替换sampler乃至切换sampler的状态方式之前多卡现在切换成单卡运行
sampler.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank,
pad=True
)
return replace_sampler(dataloader, sampler)
# evaluator
elif dist == "unrepeatdist":
args = self.get_dataloader_args(dataloader)
sampler = UnrepeatedDistributedSampler(
dataset=args.dataset,
shuffle=args.shuffle,
if isinstance(args.sampler, ReproducibleSampler):
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
elif not isinstance(args.sampler, UnrepeatedSampler):
sampler = UnrepeatedSequentialSampler(
dataset=args.dataset
)
else:
sampler = re_instantiate_sampler(args.sampler)
sampler.set_distributed(
num_replicas=self.world_size,
rank=self.global_rank

View File

@ -397,12 +397,13 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List:
"""
# # 首先将所有的都移动到cpu上并且连续防止有 pickle 出问题
# obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
if device is None:
device = torch.cuda.current_device()
if _TORCH_GREATER_EQUAL_1_8:
objs = [None for _ in range(dist.get_world_size(group))]
dist.all_gather_object(objs, obj)
objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话所有tensor都在当前卡上
return objs
if device is None:
device = torch.cuda.current_device()
group = group if group is not None else torch.distributed.group.WORLD
data = convert_to_tensors(obj, device=device)
data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group)

View File

@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
# world_size 和 rank
if FASTNLP_BACKEND_LAUNCH in os.environ:
if device is not None:
logger.warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
"up your script. And we will directly get the local device via "
"`os.environ['LOCAL_RANK']`.")
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)
@ -39,10 +39,13 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
if device < 0 and device != -1:
if device < 0:
if device != -1:
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
if device >= _could_use_device_num:
device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)]
elif device >= _could_use_device_num:
raise ValueError("The gpu device that parameter `device` specifies is not existed.")
else:
device = torch.device(f"cuda:{device}")
elif isinstance(device, Sequence):
device = list(set(device))
@ -62,7 +65,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
if not isinstance(device, List):
return TorchSingleDriver(model, device, **kwargs)
else:
logger.warning("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use "
logger.info("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use "
"`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter"
"`driver` as `TorchDDPDriver`.")
return TorchDDPDriver(model, device, **kwargs)

View File

@ -13,9 +13,8 @@ __all__ = [
from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler
from fastNLP.core.log import logger
from fastNLP.core.samplers import re_instantiate_sampler
class TorchSingleDriver(TorchDriver):
@ -130,20 +129,26 @@ class TorchSingleDriver(TorchDriver):
else:
return self._test_step(batch)
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None,
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None,
reproducible: bool = False):
if isinstance(dist, ReproducibleBatchSampler):
# 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用;
if isinstance(dist, RandomBatchSampler):
return replace_batch_sampler(dataloader, dist)
elif isinstance(dist, ReproducibleIterator):
elif isinstance(dist, ReproducibleSampler):
return replace_sampler(dataloader, dist)
if reproducible:
# 如果 dist 为 str 或者 None说明是在 trainer 初试化时调用;
args = self.get_dataloader_args(dataloader)
if isinstance(args.sampler, ReproducibleIterator):
if isinstance(args.batch_sampler, RandomBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
return replace_batch_sampler(dataloader, batch_sampler)
elif isinstance(args.sampler, ReproducibleSampler):
sampler = re_instantiate_sampler(args.sampler)
return replace_sampler(dataloader, sampler)
else:
batch_sampler = ReproducibleBatchSampler(
if reproducible:
batch_sampler = RandomBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last

View File

@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator
class TorchDriver(Driver):
@ -143,8 +143,6 @@ class TorchDriver(Driver):
:param filepath: 保存到哪个文件夹
:param only_state_dict: 是否只保存权重
:param model_save_fn:
:return:
"""
model = self.unwrap_model()
@ -184,10 +182,10 @@ class TorchDriver(Driver):
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境;
# 1. sampler 的状态,因为我们支持 resume training即精确恢复到具体的一个 batch
# 首先 pytorch 的 DataLoader 一定会有 sampler另一方面我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的
# sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`
# 首先 pytorch 的 DataLoader 一定会有 sampler另一方面我们在断点重训的时候一定会在 `set_` 中将 dataloader 的
# sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `RandomBatchSampler`
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
sampler = dataloader_args.batch_sampler
elif dataloader_args.sampler:
sampler = dataloader_args.sampler
@ -247,25 +245,25 @@ class TorchDriver(Driver):
# 3. 恢复 sampler 的状态;
dataloader_args = self.get_dataloader_args(dataloader)
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
sampler = dataloader_args.batch_sampler
elif isinstance(dataloader_args.sampler, ReproducibleIterator):
sampler = dataloader_args.sampler
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
# 说明这里需要使用 ReproduceSampler 来弄一下了
if self.is_distributed():
raise RuntimeError(
"It is not allowed to use single device checkpoint retraining before but ddp now.")
sampler = ReproducibleBatchSampler(
batch_sampler=sampler,
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
"`RandomBatchSampler` or `ReproducibleIterator`.")
else:
sampler = RandomBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states['sampler_states'])
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
# 4. 修改 trainer_state.batch_idx_in_epoch
# sampler 是类似 RandomSampler 的sampler不是 batch_sampler
if not isinstance(sampler, ReproducibleBatchSampler):
if not isinstance(sampler, RandomBatchSampler):
if dataloader_args.drop_last:
batch_idx_in_epoch = len(
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
@ -293,7 +291,7 @@ class TorchDriver(Driver):
@staticmethod
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed
with ``seed_everything(seed, workers=True)``.
See also the PyTorch documentation on

View File

@ -33,7 +33,7 @@ class TorchBackend(Backend):
if dist.is_initialized():
if method is None:
raise AggregateMethodError(should_have_aggregate_method=True)
tensor = self._gather_all(tensor)
tensor = fastnlp_torch_all_gather(tensor)
if isinstance(tensor[0], torch.Tensor):
tensor = torch.stack(tensor)
# 第一步, aggregate结果
@ -68,59 +68,6 @@ class TorchBackend(Backend):
def get_scalar(self, tensor) -> float:
return tensor.item()
@staticmethod
def _gather_all(result, group: Optional[Any] = None) -> List:
"""Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
tensors are padded, gathered and then trimmed to secure equal workload for all processes.
Args:
result: the value to sync
group: the process group to gather results from. Defaults to all processes (world)
Return:
gathered_result: list with size equal to the process group where
gathered_result[i] corresponds to result tensor from process i
"""
if group is None:
group = dist.group.WORLD
# convert tensors to contiguous format
result = result.contiguous()
world_size = dist.get_world_size(group)
dist.barrier(group=group)
# if the tensor is scalar, things are easy
if result.ndim == 0:
return _simple_gather_all_tensors(result, group, world_size)
# 1. Gather sizes of all tensors
local_size = torch.tensor(result.shape, device=result.device)
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
dist.all_gather(local_sizes, local_size, group=group)
max_size = torch.stack(local_sizes).max(dim=0).values
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
# 2. If shapes are all the same, then do a simple gather:
if all_sizes_equal:
return _simple_gather_all_tensors(result, group, world_size)
# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = torch.nn.functional.pad(result, pad_dims)
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
dist.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
return gathered_result
def tensor2numpy(self, tensor) -> np.array:
"""
将对应的tensor转为numpy对象

View File

@ -11,12 +11,12 @@ from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
class Element:
def __init__(self, value: float, aggregate_method, backend: Backend, name=None):
def __init__(self, name, value: float, aggregate_method, backend: Backend):
self.name = name
self.init_value = value
self.aggregate_method = aggregate_method
self.name = name
if backend == 'auto':
raise RuntimeError("You have to specify the backend.")
raise RuntimeError(f"You have to specify the backend for Element:{self.name}.")
elif isinstance(backend, AutoBackend):
self.backend = backend
else:
@ -34,20 +34,16 @@ class Element:
自动aggregate对应的元素
"""
self._check_value_initialized()
try:
self._value = self.backend.aggregate(self._value, self.aggregate_method)
except AggregateMethodError as e:
msg = 'If you see this message, please report a bug.'
if self.name and e.should_have_aggregate_method:
msg = f"Element:{self.name} has no specified `aggregate_method`."
elif e.should_have_aggregate_method:
msg = "Element has no specified `aggregate_method`."
elif self.name and not e.should_have_aggregate_method:
msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \
f'aggregate_method:{self.aggregate_method}.'
elif not e.should_have_aggregate_method:
msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \
f'aggregate_method:{self.aggregate_method}.'
if e.only_warn:
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
logger.warning(msg)
@ -74,6 +70,7 @@ class Element:
return self._value
def get_scalar(self) -> float:
self._check_value_initialized()
return self.backend.get_scalar(self._value)
def fill_value(self, value):
@ -95,7 +92,7 @@ class Element:
def _check_value_when_call(self):
if self.value is None:
prefix = f'Element:`{self.name}`' if self.name else 'Element'
prefix = f'Element:`{self.name}`'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")
@ -273,9 +270,10 @@ class Element:
"""
try:
if self._value is None:
prefix = f'Element:`{self.name}`' if self.name else 'Element'
prefix = f'Element:`{self.name}`'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")
return getattr(self._value, item)
except AttributeError as e:
logger.error(f"Element:{self.name} has no `{item}` attribute.")
raise e

View File

@ -35,7 +35,7 @@ class Metric:
def elements(self) -> dict:
return self._elements
def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element:
def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element:
"""
注册一个 element 对象注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用可以认为该对象即为对应 backend
tensor 直接进行加减乘除计算即可
@ -57,11 +57,9 @@ class Metric:
else:
backend = AutoBackend(backend)
# 当name为None默认为变量取得变量名
if name is None:
name = f'ele_var_{len(self._elements)}'
assert name is not None and name not in self.elements
element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name)
element = Element(name=name, value=value, aggregate_method=aggregate_method, backend=backend)
self.elements[name] = element
setattr(self, name, element)
return element

View File

@ -216,9 +216,26 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp):
class SpanFPreRecMetric(Metric):
def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None,
encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro',
beta=1, aggregate_when_get_metric: bool = True,) -> None:
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None,
only_gross: bool = True, f_type='micro',
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None:
r"""
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 支持的标签为"B"(没有label)"B-xxx"(xxx为某种label比如POS中的NN)
在解码时会将相同xxx的认为是同一个label比如['B-NN', 'E-NN']会被合并为一个'NN'.
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据 为None则使用 `pred` 取数据
:param str target: 用该key在evaluate()时从传入dict中取出target数据 为None则使用 `target` 取数据
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据为None则使用 `seq_len` 取数据
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes默认为None通过tag_vocab自动判断.
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算例如在POS tagging时传入['NN']则不会计算'NN'个label
:param bool only_gross: 是否只计算总的f1, precision, recall的值如果为False不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec
:param str f_type: `micro` `macro` . `micro` :通过先计算总体的TPFN和FP的数量再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall然后做平均各类别f的权重相同
:param float beta: f_beta分数 :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率若为1则两者平等若为2则召回率权重高于精确率
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']其中 auto 表示根据实际调用 Metric.update()
函数时传入的参数决定具体的 backend 一般情况下直接使用 'auto' 即可
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric
backend 不支持分布式时该参数无意义
"""
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
if f_type not in ('micro', 'macro'):
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
@ -249,16 +266,25 @@ class SpanFPreRecMetric(Metric):
self.only_gross = only_gross
self.tag_vocab = tag_vocab
self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
self._true_positives = {}
self._false_positives = {}
self._false_negatives = {}
for word, _ in tag_vocab:
word = word.lower()
if word != 'o':
word = word[2:]
if word in self._true_positives:
continue
self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend)
self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend)
self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend)
def get_metric(self) -> dict:
evaluate_result = {}
if not self.only_gross or self.f_type == 'macro':
tags = set(self._false_negatives.keys())
tags.update(set(self._false_positives.keys()))
tags.update(set(self._true_positives.keys()))
tags.update(self._false_positives.keys())
tags.update(self._true_positives.keys())
f_sum = 0
pre_sum = 0
rec_sum = 0
@ -266,6 +292,9 @@ class SpanFPreRecMetric(Metric):
tp = self._true_positives[tag].get_scalar()
fn = self._false_negatives[tag].get_scalar()
fp = self._false_positives[tag].get_scalar()
if tp == fn == fp == 0:
continue
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
f_sum += f
pre_sum += pre
@ -284,10 +313,17 @@ class SpanFPreRecMetric(Metric):
evaluate_result['rec'] = rec_sum / len(tags)
if self.f_type == 'micro':
tp, fn, fp = [], [], []
for val in self._true_positives.values():
tp.append(val.get_scalar())
for val in self._false_negatives.values():
fn.append(val.get_scalar())
for val in self._false_positives.values():
fp.append(val.get_scalar())
f, pre, rec = _compute_f_pre_rec(self.beta_square,
sum(val.get_scalar() for val in self._true_positives.values()),
sum(val.get_scalar() for val in self._false_negatives.values()),
sum(val.get_scalar() for val in self._false_positives.values()))
sum(tp),
sum(fn),
sum(fp))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec

View File

@ -3,19 +3,30 @@ __all__ = [
'SortedSampler',
'ConstTokenNumSampler',
'ConstantTokenNumSampler',
'UnrepeatedDistributedSampler',
'MixSampler',
'InnerSampler',
'DopedSampler',
'MixSequentialSampler',
'PollingSampler',
'ReproducibleIterator',
'ReproducibleSampler',
'RandomSampler',
're_instantiate_sampler'
"SequentialSampler",
"SortedSampler",
'UnrepeatedSampler',
'UnrepeatedRandomSampler',
"UnrepeatedSortedSampler",
"UnrepeatedSequentialSampler",
"re_instantiate_sampler",
"conversion_between_reproducible_and_unrepeated_sampler"
]
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler
from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler

View File

@ -4,7 +4,6 @@ from typing import Union, List, Iterable, Dict
__all__ = [
'MixSampler',
'InnerSampler',
'DopedSampler',
'MixSequentialSampler',
'PollingSampler'

View File

@ -1,6 +1,6 @@
__all__ = [
'BucketedBatchSampler',
"ReproducibleBatchSampler"
"RandomBatchSampler"
]
import math
@ -16,7 +16,7 @@ from fastNLP.core.log import logger
from abc import abstractmethod
class ReproducibleBatchIterator:
class ReproducibleBatchSampler:
@abstractmethod
def set_distributed(self, num_replicas, rank, pad=True):
raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.")
@ -42,13 +42,13 @@ class ReproducibleBatchIterator:
pass
class ReproducibleBatchSampler(ReproducibleBatchIterator):
class RandomBatchSampler(ReproducibleBatchSampler):
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
"""
可以使得 batch_sampler 对象状态恢复的 wrapper
:param batch_sampler: 可迭代出 数字 数字列表 的可迭代对象ReproducibleBatchSampler 将首先遍历一边该对象然后将迭代
:param batch_sampler: 可迭代出 数字 数字列表 的可迭代对象RandomBatchSampler 将首先遍历一边该对象然后将迭代
出来的序号暂存起来使用时按照 batch_size batch 大小吐出序号列表
:param batch_size: 每个 batch 的大小是多少
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample 是否丢掉
@ -138,7 +138,7 @@ class ReproducibleBatchSampler(ReproducibleBatchIterator):
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size
class BucketedBatchSampler(ReproducibleBatchIterator):
class BucketedBatchSampler(ReproducibleBatchSampler):
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
"""

View File

@ -1,25 +1,21 @@
from typing import Dict, List
from typing import Dict, List, Union
import math
import numpy as np
from fastNLP.core.log import logger
from fastNLP.core.dataset import DataSet
__all__ = [
'ReproducibleIterator',
'ReproducibleSampler',
'RandomSampler',
're_instantiate_sampler'
"SortedSampler",
"SequentialSampler"
]
def re_instantiate_sampler(sampler):
all_attributes = vars(sampler)
return type(sampler)(**all_attributes)
class ReproducibleIterator:
class ReproducibleSampler:
"""
注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`用来使我们再断点重训时重新实例化这个 sampler
注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`用来使我们再断点重训时重新实例化这个 sampler
或者 batch_sampler注意所有在 init 中初始化的变量都不能含有 _ 下横线作为开头所有不在 init 中设置的变量都必须以下横线开头
"""
@ -47,7 +43,7 @@ class ReproducibleIterator:
pass
class RandomSampler(ReproducibleIterator):
class RandomSampler(ReproducibleSampler):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
"""
@ -157,8 +153,8 @@ class RandomSampler(ReproducibleIterator):
f"we cannot use {self.__class__.__name__} to load it."
length = states['length']
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
"and current dataset."
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
f"and current dataset({len(self.dataset)})."
self.seed = states['seed']
self.epoch = states['epoch']
self.num_consumed_samples = states['num_consumed_samples']
@ -215,9 +211,132 @@ class RandomSampler(ReproducibleIterator):
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
class SequentialSampler(RandomSampler):
def __init__(self, dataset, dist_mode:str='interval', **kwargs):
"""
按照顺序读取 dataset 在多卡情况下间隔读取例如在两卡情况下卡0取 [0,2,4,..], 卡1取 [1,3,5...]
:param dataset: 实现了 __len__ 方法的数据容器
:param kwargs:
"""
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
def __iter__(self):
if self.during_iter: # 如果发现_during_iter为True说明之前的还没结束只有强制重新初始化了
self.num_consumed_samples = 0
self.during_iter = True
indices = self.generate_indices()
if self.pad:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.num_consumed_samples:]
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == self.num_left_samples
for index in indices:
self.num_consumed_samples += self.num_replicas
yield index
self.during_iter = False
self.num_consumed_samples = 0
def generate_indices(self) -> List[int]:
"""
生成随机序列
:return:
"""
return list(range(len(self.dataset)))
def state_dict(self) -> Dict:
states = {
'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据;
'sampler_type': self.__class__.__name__,
'length': len(self.dataset),
}
return states
def load_state_dict(self, states: Dict):
# 如果 self.during_iter 是 True那么 data_idx 一定是 0
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
f"we cannot use {self.__class__.__name__} to load it."
length = states['length']
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
f"and current dataset({len(self.dataset)})."
self.num_consumed_samples = states['num_consumed_samples']
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了则直接将结果重置为0
self.num_consumed_samples = 0
class SortedSampler(SequentialSampler):
def __init__(self, dataset, length:Union[str, List], **kwargs):
"""
dataset 中的数据根据 length 从长到短进行迭代在多卡情况下由于padding 最后一个 sample 可能是最长的那个 sample
:param dataset: 实现了 __len__ 方法的数据容器
:param length: 如果为 List应当与 dataset 有一样的长度表示 dataset 中每个元素的数量仅当传入的 dataset fastNLP
DataSet 时支持传入 str会将该str理解为 dataset field 名称 field 中的元素为 int则认为该值是 sample 的长度
:param seed: 设置的随机数种子
:param kwargs: fastNLP 保留使用
"""
super().__init__(dataset=dataset, **kwargs)
if isinstance(dataset, DataSet):
length = dataset.get_field(length)
if not isinstance(length[0], int):
length = list(map(len, length))
else:
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
"the length parameter can only be List[int]"
assert len(length) == len(dataset), "The length of `data` and `length` should be equal."
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的
def generate_indices(self) -> List[int]:
return self.sorted_indices
def __iter__(self):
if self.during_iter: # 如果发现_during_iter为True说明之前的还没结束只有强制重新初始化了
self.num_consumed_samples = 0
self.during_iter = True
indices = self.generate_indices()
if self.pad:
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.num_consumed_samples:]
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == self.num_left_samples
for index in indices:
self.num_consumed_samples += self.num_replicas
yield index
self.during_iter = False
self.num_consumed_samples = 0

View File

@ -7,7 +7,6 @@ __all__ = [
"SortedSampler",
'ConstTokenNumSampler',
"ConstantTokenNumSampler",
"UnrepeatedDistributedSampler",
]
from itertools import chain
@ -18,7 +17,7 @@ import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import SequentialSampler, Sampler, RandomSampler
from torch.utils.data import Sampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Sampler
@ -727,87 +726,3 @@ def k_means_bucketing(lengths, buckets):
if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
bucket_data[bucket_id].append(idx)
return bucket_data
class UnrepeatedDistributedSampler:
def __init__(self, dataset, shuffle: bool = False, seed: int = 0):
"""
考虑在多卡evaluate的场景下不能重复sample
:param dataset:
:param shuffle:
:param seed:
"""
self.dataset = dataset
self.shuffle = shuffle
self.seed = seed
# 多卡的相关的参数
self.num_replicas = 1
self.rank = 0
self.epoch = -1
def __len__(self):
"""
返回 sampler 一次完整的迭代过程会产生多少个index多卡的情况下只考虑当前rank
:return:
"""
num_common = len(self.dataset)//self.num_replicas
self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
return self.num_samples
def __iter__(self):
r"""
当前使用num_consumed_samples做法会在交替使用的时候遇到问题
Example:
>>> sampler = RandomSampler()
>>> iter1 = iter(sampler)
>>> iter2 = iter(sampler)
>>> next(iter1)
>>> next(iter2) # 当前num_consumed_samples的数量会发生变化
"""
indices = self.generate_indices()
# subsample
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == len(self)
for index in indices:
yield index
def generate_indices(self) -> List[int]:
"""
生成随机序列
:return:
"""
if self.shuffle:
indices = list(range(len(self.dataset)))
seed = self.seed + self.epoch
rng = np.random.default_rng(abs(seed))
rng.shuffle(indices)
if self.epoch < 0: # 防止用户忘记调用 set_epoch至少这样可以保证每次epoch出来的index顺序不同。
self.epoch -= 1
else:
indices = list(range(len(self.dataset)))
return indices
def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
def set_distributed(self, num_replicas, rank):
"""
该方法本质上等同于 ddp 情形下的没有完成的初始化应当在初始化该 sampler 本身后立即被调用
:param num_replicas:
:param rank:
:return:
"""
assert num_replicas>0 and isinstance(num_replicas, int)
assert isinstance(rank, int) and 0<=rank<num_replicas
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态;
self.num_replicas = num_replicas
self.rank = rank
return self

View File

@ -0,0 +1,143 @@
__all__ = [
'UnrepeatedSampler',
'UnrepeatedSortedSampler',
'UnrepeatedRandomSampler',
"UnrepeatedSequentialSampler"
]
from typing import List, Union
from fastNLP.core.dataset import DataSet
import numpy as np
class UnrepeatedSampler:
"""
在多卡场景下保证 indice 不重复的 sampler
"""
pass
class UnrepeatedRandomSampler(UnrepeatedSampler):
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs):
"""
考虑在多卡evaluate的场景下不能重复sample
:param dataset: 实现了 __len__ 方法的数据容器
:param shuffle: 如果为 True将不进行 shuffle实际上数据会以从长到短的方式输出
:param seed: 设置的随机数种子
:param kwargs: fastNLP 保留使用
"""
self.dataset = dataset
self.shuffle = shuffle
self.seed = seed
# 多卡的相关的参数
self.num_replicas = kwargs.get('num_replicas', 1)
self.rank = kwargs.get('rank', 0)
self.epoch = kwargs.get('epoch', -1)
def __len__(self):
"""
返回 sampler 一次完整的迭代过程会产生多少个index多卡的情况下只考虑当前rank
:return:
"""
num_common = len(self.dataset)//self.num_replicas
num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
return num_samples
def __iter__(self):
indices = self.generate_indices()
# subsample
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == len(self)
for index in indices:
yield index
def generate_indices(self) -> List[int]:
"""
生成随机序列
:return:
"""
if self.shuffle:
indices = list(range(len(self.dataset)))
seed = self.seed + self.epoch
rng = np.random.default_rng(abs(seed))
rng.shuffle(indices)
if self.epoch < 0: # 防止用户忘记调用 set_epoch至少这样可以保证每次epoch出来的index顺序不同。
self.epoch -= 1
else:
indices = list(range(len(self.dataset)))
return indices
def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
def set_distributed(self, num_replicas, rank):
"""
该方法本质上等同于 ddp 情形下的没有完成的初始化应当在初始化该 sampler 本身后立即被调用
:param num_replicas:
:param rank:
:return:
"""
assert num_replicas>0 and isinstance(num_replicas, int)
assert isinstance(rank, int) and 0<=rank<num_replicas
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态;
self.num_replicas = num_replicas
self.rank = rank
return self
class UnrepeatedSortedSampler(UnrepeatedRandomSampler):
def __init__(self, dataset, length:Union[str, List], **kwargs):
"""
dataset 中的数据根据 length 从长到短进行迭代并且保证在多卡场景下数据不重复 sampler 可能导致各个机器上的
batch 数量不完全一致
:param dataset: 实现了 __len__ 方法的数据容器
:param length: 如果为 List应当与 dataset 有一样的长度表示 dataset 中每个元素的数量仅当传入的 dataset fastNLP
DataSet 时支持传入 str会将该str理解为 dataset field 名称 field 中的元素为 int则认为该值是 sample 的长度
:param kwargs: fastNLP 保留使用
"""
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
if isinstance(dataset, DataSet):
length = dataset.get_field(length)
if not isinstance(length[0], int):
length = list(map(len, length))
else:
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
"the length parameter can only be List[int]"
assert len(length) == len(dataset), "The length of `data` and `length` should be equal."
length = np.array(length, dtype=int) # 按照长到短排列的序号。
self.sorted_indices = np.argsort(length)[::-1].tolist() # 按长度从高到低排序的
def generate_indices(self) -> List[int]:
return self.sorted_indices
class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
def __init__(self, dataset, **kwargs):
"""
按照顺序读取 dataset在多卡情况下间隔读取例如在两卡情况下卡0取 [0,2,4,..], 卡1取 [1,3,5...]
:param dataset: 实现了 __len__ 方法的数据容器
:param kwargs:
"""
super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs)
def __iter__(self):
indices = self.generate_indices()
indices = indices[self.rank:len(indices):self.num_replicas]
for index in indices:
yield index
def generate_indices(self) -> List[int]:
return list(range(len(self.dataset)))

View File

@ -0,0 +1,42 @@
__all__ = [
're_instantiate_sampler',
'conversion_between_reproducible_and_unrepeated_sampler'
]
from fastNLP.core.samplers.unrepeated_sampler import *
from fastNLP.core.samplers.reproducible_sampler import *
def conversion_between_reproducible_and_unrepeated_sampler(sampler):
"""
sampler 替换成其对应的 reproducible 版本或 unrepeated 版本如果输入是 UnrepeatedSampler 但是没找到对应的
ReproducibleSampler
:param sampler:
:return:
"""
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
"The sampler must be UnrepeatedSampler or ReproducibleSampler"
if isinstance(sampler, UnrepeatedSampler):
if isinstance(sampler, UnrepeatedRandomSampler):
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
elif isinstance(sampler, UnrepeatedSequentialSampler):
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
elif isinstance(sampler, UnrepeatedSortedSampler):
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
raise TypeError(f"{sampler.__class__} has no unrepeated version.")
else:
if isinstance(sampler, RandomSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
elif isinstance(sampler, SequentialSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
elif isinstance(sampler, SortedSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
raise TypeError(f"{sampler.__class__} has no reproducible version.")
def re_instantiate_sampler(sampler, new_sampler_class=None):
all_attributes = vars(sampler)
if new_sampler_class is not None:
return new_sampler_class(**all_attributes)
return type(sampler)(**all_attributes)

View File

@ -96,6 +96,7 @@ class FRichProgress(Progress, metaclass=Singleton):
# start new
self.start()
self.console.show_cursor(show=True)
return self
def set_transient(self, transient: bool = True):
@ -149,6 +150,9 @@ class FRichProgress(Progress, metaclass=Singleton):
super().stop_task(task_id)
super().remove_task(task_id)
def start(self) -> None:
super().start()
self.console.show_cursor(show=True)
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
f_rich_progress = FRichProgress().new_progess(
@ -161,7 +165,7 @@ if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
TextColumn("{task.fields[post_desc]}", justify="right"),
transient=True,
disable=False,
speed_estimate_period=10
speed_estimate_period=1
)
else:
f_rich_progress = DummyFRichProgress()

View File

@ -44,6 +44,9 @@ __all__ = [
]
def get_fn_arg_names(fn: Callable) -> List[str]:
r"""
返回一个函数的所有参数的名字

View File

@ -153,7 +153,7 @@ def seed_jittor_global_seed(global_seed):
pass
def dump_fastnlp_backend(default:bool = False):
def dump_fastnlp_backend(default:bool = False, backend=None):
"""
fastNLP 的设置写入到 ~/.fastNLP/envs/ 文件夹下
default True则保存的文件为 ~/.fastNLP/envs/default.json
@ -165,6 +165,7 @@ def dump_fastnlp_backend(default:bool = False):
会保存的环境变量为 FASTNLP_BACKEND
:param default:
:param backend: 保存使用的 backend 为哪个值允许的值有 ['torch', 'paddle', 'jittor']如果为 None 则使用环境变量中的值
:return:
"""
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
@ -179,10 +180,16 @@ def dump_fastnlp_backend(default:bool = False):
os.makedirs(os.path.dirname(env_path), exist_ok=True)
envs = {}
assert backend in SUPPORT_BACKENDS, f"fastNLP only supports {SUPPORT_BACKENDS} right now."
if backend is None:
if FASTNLP_BACKEND in os.environ:
envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND]
else:
envs[FASTNLP_BACKEND] = backend
if len(envs):
with open(env_path, 'w', encoding='utf8') as f:
json.dump(fp=f, obj=envs)
print(f"Writing the default fastNLP backend:{envs[FASTNLP_BACKEND]} to {env_path}.")
else:
raise RuntimeError("No backend specified.")

View File

@ -47,7 +47,8 @@ def set_env_on_import_paddle():
# TODO jittor may need set this
def set_env_on_import_jittor():
# todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH
pass
if 'log_silent' not in os.environ:
os.environ['log_silent'] = '1'
def set_env_on_import():
@ -63,7 +64,7 @@ def set_env_on_import():
# fastNLP 内部使用的一些变量
if FASTNLP_LAUNCH_TIME not in os.environ:
cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%M_%f')}"
cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')}"
os.environ[FASTNLP_LAUNCH_TIME] = cur_time
# 设置对应的值

View File

@ -8,7 +8,7 @@ import torch.distributed as dist
from pathlib import Path
import re
from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback
from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK
@ -80,16 +80,23 @@ def test_model_checkpoint_callback_1(
version,
only_state_dict
):
# def test_model_checkpoint_callback_1(
# model_and_optimizers: TrainerParameters,
# driver='torch_ddp',
# device=[0, 1],
# version=1,
# only_state_dict=True
# ):
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)
if version == 0:
callbacks = [
CheckpointCallback(
ModelCheckpointCallback(
monitor="acc",
save_folder=path,
save_every_n_epochs=1,
save_every_n_global_batches=123, # 避免和 epoch 的保存重复;
save_every_n_batches=123, # 避免和 epoch 的保存重复;
save_topk=None,
save_last=False,
save_on_exception=None,
@ -98,11 +105,11 @@ def test_model_checkpoint_callback_1(
]
elif version == 1:
callbacks = [
CheckpointCallback(
ModelCheckpointCallback(
monitor="acc",
save_folder=path,
save_every_n_epochs=3,
save_every_n_global_batches=None,
save_every_n_batches=None,
save_topk=2,
save_last=True,
save_on_exception=None,
@ -121,7 +128,6 @@ def test_model_checkpoint_callback_1(
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
n_epochs=10,
callbacks=callbacks,
output_from_new_proc="all"
@ -134,31 +140,31 @@ def test_model_checkpoint_callback_1(
if version == 0:
if driver == "torch":
assert "epoch_10-global_batch_250-acc" in all_saved_model_paths
assert "epoch_4-global_batch_123-acc" in all_saved_model_paths
assert "model-epoch_10" in all_saved_model_paths
assert "model-epoch_4-batch_123" in all_saved_model_paths
epoch_save_path = all_saved_model_paths["epoch_10-global_batch_250-acc"]
step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"]
epoch_save_path = all_saved_model_paths["model-epoch_10"]
step_save_path = all_saved_model_paths["model-epoch_4-batch_123"]
assert len(all_saved_model_paths) == 12
# ddp 下的文件名不同因为同样的数据ddp 用了更少的步数跑完;
else:
assert "epoch_6-global_batch_78-acc" in all_saved_model_paths
assert "epoch_9-global_batch_123-acc" in all_saved_model_paths
assert "model-epoch_6" in all_saved_model_paths
assert "model-epoch_9-batch_123" in all_saved_model_paths
epoch_save_path = all_saved_model_paths["epoch_6-global_batch_78-acc"]
step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"]
epoch_save_path = all_saved_model_paths["model-epoch_6"]
step_save_path = all_saved_model_paths["model-epoch_9-batch_123"]
assert len(all_saved_model_paths) == 11
all_state_dicts = [epoch_save_path, step_save_path]
elif version == 1:
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
if driver == "torch":
assert "epoch_9-global_batch_225-acc" in all_saved_model_paths
assert "last" in all_saved_model_paths
assert "model-epoch_9" in all_saved_model_paths
assert "model-last" in all_saved_model_paths
aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
each_folder_name = pattern.findall(each_folder_name)
@ -166,15 +172,15 @@ def test_model_checkpoint_callback_1(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 2
epoch_save_path = all_saved_model_paths["epoch_9-global_batch_225-acc"]
last_save_path = all_saved_model_paths["last"]
epoch_save_path = all_saved_model_paths["model-epoch_9"]
last_save_path = all_saved_model_paths["model-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
assert len(all_saved_model_paths) == 6
# ddp 下的文件名不同因为同样的数据ddp 用了更少的步数跑完;
else:
assert "epoch_9-global_batch_117-acc" in all_saved_model_paths
assert "last" in all_saved_model_paths
assert "model-epoch_9" in all_saved_model_paths
assert "model-last" in all_saved_model_paths
aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
@ -183,8 +189,8 @@ def test_model_checkpoint_callback_1(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 2
epoch_save_path = all_saved_model_paths["epoch_9-global_batch_117-acc"]
last_save_path = all_saved_model_paths["last"]
epoch_save_path = all_saved_model_paths["model-epoch_9"]
last_save_path = all_saved_model_paths["model-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
assert len(all_saved_model_paths) == 6
@ -212,7 +218,7 @@ def test_model_checkpoint_callback_1(
finally:
synchronize_safe_rm(path)
# pass
pass
if dist.is_initialized():
dist.destroy_process_group()
@ -238,11 +244,11 @@ def test_model_checkpoint_callback_2(
raise NotImplementedError
callbacks = [
CheckpointCallback(
ModelCheckpointCallback(
monitor="acc1",
save_folder=path,
save_every_n_epochs=None,
save_every_n_global_batches=None,
save_every_n_batches=None,
save_topk=None,
save_last=False,
save_on_exception=NotImplementedError,
@ -279,12 +285,12 @@ def test_model_checkpoint_callback_2(
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
if driver == "torch":
assert "epoch_4-global_batch_100-acc_NotImplementedError" in all_saved_model_paths
exception_model_path = all_saved_model_paths["epoch_4-global_batch_100-acc_NotImplementedError"]
assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths
exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"]
# ddp 下的文件名不同因为同样的数据ddp 用了更少的步数跑完;
else:
assert "epoch_4-global_batch_52-acc_NotImplementedError" in all_saved_model_paths
exception_model_path = all_saved_model_paths["epoch_4-global_batch_52-acc_NotImplementedError"]
assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths
exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"]
assert len(all_saved_model_paths) == 1
all_state_dicts = [exception_model_path]
@ -332,12 +338,11 @@ def test_trainer_checkpoint_callback_1(
if version == 0:
callbacks = [
CheckpointCallback(
TrainerCheckpointCallback(
monitor="acc",
is_trainer_checkpoint=True,
save_folder=path,
save_every_n_epochs=7,
save_every_n_global_batches=123, # 避免和 epoch 的保存重复;
save_every_n_batches=123, # 避免和 epoch 的保存重复;
save_topk=None,
save_last=False,
save_on_exception=None,
@ -346,12 +351,11 @@ def test_trainer_checkpoint_callback_1(
]
elif version == 1:
callbacks = [
CheckpointCallback(
TrainerCheckpointCallback(
monitor="acc",
is_trainer_checkpoint=True,
save_folder=path,
save_every_n_epochs=None,
save_every_n_global_batches=None,
save_every_n_batches=None,
save_topk=2,
save_last=True,
save_on_exception=None,
@ -383,31 +387,31 @@ def test_trainer_checkpoint_callback_1(
if version == 0:
if driver == "torch":
assert "epoch_7-global_batch_175-acc" in all_saved_model_paths
assert "epoch_4-global_batch_123-acc" in all_saved_model_paths
assert "trainer-epoch_7" in all_saved_model_paths
assert "trainer-epoch_4-batch_123" in all_saved_model_paths
epoch_save_path = all_saved_model_paths["epoch_7-global_batch_175-acc"]
step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"]
epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
step_save_path = all_saved_model_paths["trainer-epoch_4-batch_123"]
assert len(all_saved_model_paths) == 3
# ddp 下的文件名不同因为同样的数据ddp 用了更少的步数跑完;
else:
assert "epoch_7-global_batch_91-acc" in all_saved_model_paths
assert "epoch_9-global_batch_123-acc" in all_saved_model_paths
assert "trainer-epoch_7" in all_saved_model_paths
assert "trainer-epoch_9-batch_123" in all_saved_model_paths
epoch_save_path = all_saved_model_paths["epoch_7-global_batch_91-acc"]
step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"]
epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
step_save_path = all_saved_model_paths["trainer-epoch_9-batch_123"]
assert len(all_saved_model_paths) == 2
all_state_dicts = [epoch_save_path, step_save_path]
elif version == 1:
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
if driver == "torch":
assert "last" in all_saved_model_paths
assert "trainer-last" in all_saved_model_paths
aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
each_folder_name = pattern.findall(each_folder_name)
@ -415,13 +419,13 @@ def test_trainer_checkpoint_callback_1(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 2
last_save_path = all_saved_model_paths["last"]
last_save_path = all_saved_model_paths["trainer-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
assert len(all_saved_model_paths) == 3
# ddp 下的文件名不同因为同样的数据ddp 用了更少的步数跑完;
else:
assert "last" in all_saved_model_paths
assert "trainer-last" in all_saved_model_paths
aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
@ -430,7 +434,7 @@ def test_trainer_checkpoint_callback_1(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 2
last_save_path = all_saved_model_paths["last"]
last_save_path = all_saved_model_paths["trainer-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
assert len(all_saved_model_paths) == 3
@ -474,10 +478,11 @@ def test_trainer_checkpoint_callback_2(
device,
version
):
pytest.skip("Skip transformers test for now.")
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)
import transformers
import transformers # 版本4.16.2
import torch
from torchmetrics import Accuracy
from transformers import AutoModelForSequenceClassification
@ -587,12 +592,11 @@ def test_trainer_checkpoint_callback_2(
if version == 0:
callbacks = [
CheckpointCallback(
TrainerCheckpointCallback(
monitor="acc",
is_trainer_checkpoint=True,
save_folder=path,
save_every_n_epochs=None,
save_every_n_global_batches=50,
save_every_n_batches=50,
save_topk=None,
save_last=False,
save_on_exception=None,
@ -601,12 +605,11 @@ def test_trainer_checkpoint_callback_2(
]
elif version == 1:
callbacks = [
CheckpointCallback(
TrainerCheckpointCallback(
monitor="acc",
is_trainer_checkpoint=True,
save_folder=path,
save_every_n_epochs=None,
save_every_n_global_batches=None,
save_every_n_batches=None,
save_topk=1,
save_last=True,
save_on_exception=None,
@ -638,27 +641,27 @@ def test_trainer_checkpoint_callback_2(
if version == 0:
if driver == "torch":
assert "epoch_1-global_batch_200-acc" in all_saved_model_paths
assert "trainer-epoch_1-batch_200" in all_saved_model_paths
epoch_save_path = all_saved_model_paths["epoch_1-global_batch_200-acc"]
epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"]
assert len(all_saved_model_paths) == 4
# ddp 下的文件名不同因为同样的数据ddp 用了更少的步数跑完;
else:
assert "epoch_1-global_batch_100-acc" in all_saved_model_paths
assert "trainer-epoch_1-batch_100" in all_saved_model_paths
epoch_save_path = all_saved_model_paths["epoch_1-global_batch_100-acc"]
epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_100"]
assert len(all_saved_model_paths) == 2
all_state_dicts = [epoch_save_path]
elif version == 1:
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
if driver == "torch":
assert "last" in all_saved_model_paths
assert "trainer-last" in all_saved_model_paths
aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
each_folder_name = pattern.findall(each_folder_name)
@ -666,13 +669,13 @@ def test_trainer_checkpoint_callback_2(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 1
last_save_path = all_saved_model_paths["last"]
last_save_path = all_saved_model_paths["trainer-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
assert len(all_saved_model_paths) == 2
# ddp 下的文件名不同因为同样的数据ddp 用了更少的步数跑完;
else:
assert "last" in all_saved_model_paths
assert "trainer-last" in all_saved_model_paths
aLL_topk_folders = []
for each_folder_name in all_saved_model_paths:
@ -681,7 +684,7 @@ def test_trainer_checkpoint_callback_2(
aLL_topk_folders.append(each_folder_name[0])
assert len(aLL_topk_folders) == 1
last_save_path = all_saved_model_paths["last"]
last_save_path = all_saved_model_paths["trainer-last"]
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
assert len(all_saved_model_paths) == 2

View File

@ -1,4 +1,4 @@
import unittest
import pytest
from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader
from fastNLP.core.dataset import DataSet
@ -17,7 +17,7 @@ class RandomDataset(Dataset):
return 10
class TestPaddle(unittest.TestCase):
class TestPaddle:
def test_init(self):
# ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})

View File

@ -1,25 +1,25 @@
import unittest
import pytest
from fastNLP.core.dataloaders.torch_dataloader import FDataLoader, prepare_dataloader
from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader
from fastNLP.core.dataset import DataSet
from fastNLP.io.data_bundle import DataBundle
class TestFdl(unittest.TestCase):
class TestFdl:
def test_init_v1(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
fdl = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True)
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True)
# for batch in fdl:
# print(batch)
fdl1 = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True)
fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True)
# for batch in fdl1:
# print(batch)
def test_set_padding(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
ds.set_pad_val("x", val=-1)
fdl = FDataLoader(ds, batch_size=3)
fdl = TorchDataLoader(ds, batch_size=3)
fdl.set_input("x", "y")
for batch in fdl:
print(batch)
@ -36,7 +36,7 @@ class TestFdl(unittest.TestCase):
_dict["Y"].append(ins['y'])
return _dict
fdl = FDataLoader(ds, batch_size=3, as_numpy=True)
fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True)
fdl.set_input("x", "y")
fdl.add_collator(collate_fn)
for batch in fdl:
@ -44,7 +44,7 @@ class TestFdl(unittest.TestCase):
def test_get_batch_indices(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
fdl = FDataLoader(ds, batch_size=3, shuffle=True)
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True)
fdl.set_input("y", "x")
for batch in fdl:
print(fdl.get_batch_indices())
@ -67,30 +67,30 @@ class TestFdl(unittest.TestCase):
return object.__getattribute__(self, item)
dataset = _DataSet()
dl = FDataLoader(dataset, batch_size=2, shuffle=True)
dl = TorchDataLoader(dataset, batch_size=2, shuffle=True)
# dl.set_inputs('data', 'labels')
# dl.set_pad_val('labels', val=None)
for batch in dl:
print(batch)
print(dl.get_batch_indices())
def test_prepare_dataloader(self):
def test_prepare_torch_dataloader(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
dl = prepare_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
assert isinstance(dl, FDataLoader)
dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
assert isinstance(dl, TorchDataLoader)
ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
dbl = DataBundle(datasets={'train': ds, 'val': ds1})
dl_bundle = prepare_dataloader(dbl)
assert isinstance(dl_bundle['train'], FDataLoader)
assert isinstance(dl_bundle['val'], FDataLoader)
dl_bundle = prepare_torch_dataloader(dbl)
assert isinstance(dl_bundle['train'], TorchDataLoader)
assert isinstance(dl_bundle['val'], TorchDataLoader)
ds_dict = {'train_1': ds, 'val': ds1}
dl_dict = prepare_dataloader(ds_dict)
assert isinstance(dl_dict['train_1'], FDataLoader)
assert isinstance(dl_dict['val'], FDataLoader)
dl_dict = prepare_torch_dataloader(ds_dict)
assert isinstance(dl_dict['train_1'], TorchDataLoader)
assert isinstance(dl_dict['val'], TorchDataLoader)
sequence = [ds, ds1]
seq_ds = prepare_dataloader(sequence)
assert isinstance(seq_ds[0], FDataLoader)
assert isinstance(seq_ds[1], FDataLoader)
seq_ds = prepare_torch_dataloader(sequence)
assert isinstance(seq_ds[0], TorchDataLoader)
assert isinstance(seq_ds[1], TorchDataLoader)

View File

@ -1,12 +1,12 @@
import os
import unittest
import pytest
import numpy as np
from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException
class TestDataSetInit(unittest.TestCase):
class TestDataSetInit:
"""初始化DataSet的办法有以下几种
1) 用dict:
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
@ -24,46 +24,46 @@ class TestDataSetInit(unittest.TestCase):
def test_init_v1(self):
# 一维list
ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40)
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
assert ds.field_arrays["y"].content == [[5, 6], ] * 40
def test_init_v2(self):
# 用dict
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
assert ds.field_arrays["y"].content == [[5, 6], ] * 40
def test_init_assert(self):
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
_ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100})
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
_ = DataSet([[1, 2, 3, 4]] * 10)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
_ = DataSet(0.00001)
class TestDataSetMethods(unittest.TestCase):
class TestDataSetMethods:
def test_append(self):
dd = DataSet()
for _ in range(3):
dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
self.assertEqual(len(dd), 3)
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)
assert len(dd) == 3
assert dd.field_arrays["x"].content == [[1, 2, 3, 4]] * 3
assert dd.field_arrays["y"].content == [[5, 6]] * 3
def test_add_field(self):
dd = DataSet()
dd.add_field("x", [[1, 2, 3]] * 10)
dd.add_field("y", [[1, 2, 3, 4]] * 10)
dd.add_field("z", [[5, 6]] * 10)
self.assertEqual(len(dd), 10)
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10)
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)
assert len(dd) == 10
assert dd.field_arrays["x"].content == [[1, 2, 3]] * 10
assert dd.field_arrays["y"].content == [[1, 2, 3, 4]] * 10
assert dd.field_arrays["z"].content == [[5, 6]] * 10
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
dd.add_field("??", [[1, 2]] * 40)
def test_delete_field(self):
@ -71,8 +71,8 @@ class TestDataSetMethods(unittest.TestCase):
dd.add_field("x", [[1, 2, 3]] * 10)
dd.add_field("y", [[1, 2, 3, 4]] * 10)
dd.delete_field("x")
self.assertFalse("x" in dd.field_arrays)
self.assertTrue("y" in dd.field_arrays)
assert ("x" in dd.field_arrays) == False
assert "y" in dd.field_arrays
def test_delete_instance(self):
dd = DataSet()
@ -80,99 +80,113 @@ class TestDataSetMethods(unittest.TestCase):
dd.add_field("x", [[1, 2, 3]] * old_length)
dd.add_field("y", [[1, 2, 3, 4]] * old_length)
dd.delete_instance(0)
self.assertEqual(len(dd), old_length - 1)
assert len(dd) == old_length - 1
dd.delete_instance(0)
self.assertEqual(len(dd), old_length - 2)
assert len(dd) == old_length - 2
def test_getitem(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ins_1, ins_0 = ds[0], ds[1]
self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance))
self.assertEqual(ins_1["x"], [1, 2, 3, 4])
self.assertEqual(ins_1["y"], [5, 6])
self.assertEqual(ins_0["x"], [1, 2, 3, 4])
self.assertEqual(ins_0["y"], [5, 6])
assert isinstance(ins_1, Instance) and isinstance(ins_0, Instance) == True
assert ins_1["x"] == [1, 2, 3, 4]
assert ins_1["y"] == [5, 6]
assert ins_0["x"] == [1, 2, 3, 4]
assert ins_0["y"] == [5, 6]
sub_ds = ds[:10]
self.assertTrue(isinstance(sub_ds, DataSet))
self.assertEqual(len(sub_ds), 10)
assert isinstance(sub_ds, DataSet) == True
assert len(sub_ds) == 10
sub_ds_1 = ds[[10, 0, 2, 3]]
self.assertTrue(isinstance(sub_ds_1, DataSet))
self.assertEqual(len(sub_ds_1), 4)
assert isinstance(sub_ds_1, DataSet) == True
assert len(sub_ds_1) == 4
field_array = ds['x']
self.assertTrue(isinstance(field_array, FieldArray))
self.assertEqual(len(field_array), 40)
assert isinstance(field_array, FieldArray) == True
assert len(field_array) == 40
def test_setitem(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.add_field('i', list(range(len(ds))))
assert ds.get_field('i').content == list(range(len(ds)))
import random
random.shuffle(ds)
import numpy as np
np.random.shuffle(ds)
assert ds.get_field('i').content != list(range(len(ds)))
ins1 = ds[1]
ds[2] = ds[1]
assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y']
def test_get_item_error(self):
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
_ = ds[40:]
with self.assertRaises(KeyError):
with pytest.raises(KeyError):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
_ = ds["kom"]
def test_len_(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertEqual(len(ds), 40)
assert len(ds) == 40
ds = DataSet()
self.assertEqual(len(ds), 0)
assert len(ds) == 0
def test_add_fieldarray(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 40))
self.assertEqual(ds['z'].content, [[7, 8]]*40)
assert ds['z'].content == [[7, 8]] * 40
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 10))
with self.assertRaises(TypeError):
with pytest.raises(TypeError):
ds.add_fieldarray('z', [1, 2, 4])
def test_copy_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.copy_field('x', 'z')
self.assertEqual(ds['x'].content, ds['z'].content)
assert ds['x'].content == ds['z'].content
def test_has_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertTrue(ds.has_field('x'))
self.assertFalse(ds.has_field('z'))
assert ds.has_field('x') == True
assert ds.has_field('z') == False
def test_get_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
with self.assertRaises(KeyError):
with pytest.raises(KeyError):
ds.get_field('z')
x_array = ds.get_field('x')
self.assertEqual(x_array.content, [[1, 2, 3, 4]] * 40)
assert x_array.content == [[1, 2, 3, 4]] * 40
def test_get_all_fields(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
field_arrays = ds.get_all_fields()
self.assertEqual(field_arrays["x"], [[1, 2, 3, 4]] * 40)
self.assertEqual(field_arrays['y'], [[5, 6]] * 40)
assert field_arrays["x"].content == [[1, 2, 3, 4]] * 40
assert field_arrays['y'].content == [[5, 6]] * 40
def test_get_field_names(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
field_names = ds.get_field_names()
self.assertTrue('x' in field_names)
self.assertTrue('y' in field_names)
assert 'x' in field_names
assert 'y' in field_names
def test_apply(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000})
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx')
self.assertTrue("rx" in ds.field_arrays)
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])
assert ("rx" in ds.field_arrays) == True
assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1]
ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
self.assertEqual(ds.field_arrays["y"].content[0], 2)
assert ds.field_arrays["y"].content[0] == 2
res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
self.assertTrue(isinstance(res, list) and len(res) > 0)
self.assertTrue(res[0], 4)
assert (isinstance(res, list) and len(res) > 0) == True
assert res[0] == 4
ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k")
# expect no exception raised
@ -192,6 +206,7 @@ class TestDataSetMethods(unittest.TestCase):
def modify_inplace(instance):
instance['words'] = 1
ds.apply(modify_inplace)
# with self.assertRaises(TypeError):
# ds.apply(modify_inplace)
@ -216,48 +231,48 @@ class TestDataSetMethods(unittest.TestCase):
T.apply_more(func_1)
# print(T['c'][0, 1, 2])
self.assertEqual(list(T["c"].content), [2, 4, 6])
self.assertEqual(list(T["d"].content), [1, 4, 9])
assert list(T["c"].content) == [2, 4, 6]
assert list(T["d"].content) == [1, 4, 9]
res = T.apply_field_more(func_2, "a", modify_fields=False)
self.assertEqual(list(T["c"].content), [2, 4, 6])
self.assertEqual(list(T["d"].content), [1, 4, 9])
self.assertEqual(list(res["c"]), [3, 6, 9])
self.assertEqual(list(res["d"]), [1, 8, 27])
assert list(T["c"].content) == [2, 4, 6]
assert list(T["d"].content) == [1, 4, 9]
assert list(res["c"]) == [3, 6, 9]
assert list(res["d"]) == [1, 8, 27]
with self.assertRaises(ApplyResultException) as e:
with pytest.raises(ApplyResultException) as e:
T.apply_more(func_err_1)
print(e)
with self.assertRaises(ApplyResultException) as e:
with pytest.raises(ApplyResultException) as e:
T.apply_field_more(func_err_2, "a")
print(e)
def test_drop(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)
self.assertEqual(len(ds), 20)
assert len(ds) == 20
def test_contains(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertTrue("x" in ds)
self.assertTrue("y" in ds)
self.assertFalse("z" in ds)
assert ("x" in ds) == True
assert ("y" in ds) == True
assert ("z" in ds) == False
def test_rename_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.rename_field("x", "xx")
self.assertTrue("xx" in ds)
self.assertFalse("x" in ds)
assert ("xx" in ds) == True
assert ("x" in ds) == False
with self.assertRaises(KeyError):
with pytest.raises(KeyError):
ds.rename_field("yyy", "oo")
def test_split(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
d1, d2 = ds.split(0.1)
self.assertEqual(len(d1), len(ds)*0.9)
self.assertEqual(len(d2), len(ds)*0.1)
assert len(d2) == (len(ds) * 0.9)
assert len(d1) == (len(ds) * 0.1)
def test_add_field_v2(self):
ds = DataSet({"x": [3, 4]})
@ -268,14 +283,14 @@ class TestDataSetMethods(unittest.TestCase):
def test_save_load(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.save("./my_ds.pkl")
self.assertTrue(os.path.exists("./my_ds.pkl"))
assert os.path.exists("./my_ds.pkl") == True
ds_1 = DataSet.load("./my_ds.pkl")
os.remove("my_ds.pkl")
def test_add_null(self):
ds = DataSet()
with self.assertRaises(RuntimeError) as RE:
with pytest.raises(RuntimeError) as RE:
ds.add_field('test', [])
def test_concat(self):
@ -287,16 +302,16 @@ class TestDataSetMethods(unittest.TestCase):
ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]})
ds3 = ds1.concat(ds2)
self.assertEqual(len(ds3), 20)
assert len(ds3) == 20
self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4])
self.assertListEqual(ds1[10]['x'], [4, 3, 2, 1])
assert ds1[9]['x'] == [1, 2, 3, 4]
assert ds1[10]['x'] == [4, 3, 2, 1]
ds2[0]['x'][0] = 100
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了
assert ds3[10]['x'][0] == 4 # 不改变copy后的field了
ds3[10]['x'][0] = -100
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了
assert ds2[0]['x'][0] == 100 # 不改变copy前的field了
# 测试inplace
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
@ -304,19 +319,19 @@ class TestDataSetMethods(unittest.TestCase):
ds3 = ds1.concat(ds2, inplace=True)
ds2[0]['x'][0] = 100
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了
assert ds3[10]['x'][0] == 4 # 不改变copy后的field了
ds3[10]['x'][0] = -100
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了
assert ds2[0]['x'][0] == 100 # 不改变copy前的field了
ds3[0]['x'][0] = 100
self.assertEqual(ds1[0]['x'][0], 100) # 改变copy前的field了
assert ds1[0]['x'][0] == 100 # 改变copy前的field了
# 测试mapping
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]})
ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
self.assertEqual(len(ds3), 20)
assert len(ds3) == 20
# 测试忽略掉多余的
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
@ -326,7 +341,7 @@ class TestDataSetMethods(unittest.TestCase):
# 测试报错
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]})
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
ds3 = ds1.concat(ds2, field_mapping={'X': 'x'})
def test_instance_field_disappear_bug(self):
@ -334,7 +349,7 @@ class TestDataSetMethods(unittest.TestCase):
data.copy_field(field_name='raw_chars', new_field_name='chars')
_data = data[:1]
for field_name in ['raw_chars', 'target', 'chars']:
self.assertTrue(_data.has_field(field_name))
assert _data.has_field(field_name) == True
def test_from_pandas(self):
import pandas as pd
@ -342,8 +357,8 @@ class TestDataSetMethods(unittest.TestCase):
df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
ds = DataSet.from_pandas(df)
print(ds)
self.assertEqual(ds['x'].content, [1, 2, 3])
self.assertEqual(ds['y'].content, [4, 5, 6])
assert ds['x'].content == [1, 2, 3]
assert ds['y'].content == [4, 5, 6]
def test_to_pandas(self):
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
@ -352,7 +367,7 @@ class TestDataSetMethods(unittest.TestCase):
def test_to_csv(self):
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
ds.to_csv("1.csv")
self.assertTrue(os.path.exists("1.csv"))
assert os.path.exists("1.csv") == True
os.remove("1.csv")
def test_add_collate_fn(self):
@ -360,15 +375,14 @@ class TestDataSetMethods(unittest.TestCase):
def collate_fn(item):
return item
ds.add_collate_fn(collate_fn)
self.assertEqual(len(ds.collate_fns.collators), 2)
ds.add_collate_fn(collate_fn)
def test_get_collator(self):
from typing import Callable
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
collate_fn = ds.get_collator()
self.assertEqual(isinstance(collate_fn, Callable), True)
assert isinstance(collate_fn, Callable) == True
def test_add_seq_len(self):
ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
@ -380,7 +394,7 @@ class TestDataSetMethods(unittest.TestCase):
ds.set_target('x')
class TestFieldArrayInit(unittest.TestCase):
class TestFieldArrayInit:
"""
1 如果DataSet使用dict初始化那么在add_field中会构造FieldArray
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
@ -428,7 +442,6 @@ class TestFieldArrayInit(unittest.TestCase):
# list of array
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])])
def test_init_v8(self):
# 二维list
val = np.array([[1, 2], [3, 4]])
@ -436,78 +449,78 @@ class TestFieldArrayInit(unittest.TestCase):
fa.append(val)
class TestFieldArray(unittest.TestCase):
class TestFieldArray:
def test_main(self):
fa = FieldArray("x", [1, 2, 3, 4, 5])
self.assertEqual(len(fa), 5)
assert len(fa) == 5
fa.append(6)
self.assertEqual(len(fa), 6)
assert len(fa) == 6
self.assertEqual(fa[-1], 6)
self.assertEqual(fa[0], 1)
assert fa[-1] == 6
assert fa[0] == 1
fa[-1] = 60
self.assertEqual(fa[-1], 60)
assert fa[-1] == 60
self.assertEqual(fa.get(0), 1)
self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray))
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3])
assert fa.get(0) == 1
assert isinstance(fa.get([0, 1, 2]), np.ndarray) == True
assert list(fa.get([0, 1, 2])) == [1, 2, 3]
def test_getitem_v1(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]
ans = fa[[0, 1]]
self.assertTrue(isinstance(ans, np.ndarray))
self.assertTrue(isinstance(ans[0], np.ndarray))
self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5])
self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5])
self.assertEqual(ans.dtype, np.float64)
assert isinstance(ans, np.ndarray) == True
assert isinstance(ans[0], np.ndarray) == True
assert ans[0].tolist() == [1.1, 2.2, 3.3, 4.4, 5.5]
assert ans[1].tolist() == [1, 2, 3, 4, 5]
assert ans.dtype == np.float64
def test_getitem_v2(self):
x = np.random.rand(10, 5)
fa = FieldArray("my_field", x)
indices = [0, 1, 3, 4, 6]
for a, b in zip(fa[indices], x[indices]):
self.assertListEqual(a.tolist(), b.tolist())
assert a.tolist() == b.tolist()
def test_append(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
self.assertEqual(len(fa), 3)
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])
assert len(fa) == 3
assert fa[2] == [1.2, 2.3, 3.4, 4.5, 5.6]
def test_pop(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
fa.pop(0)
self.assertEqual(len(fa), 1)
self.assertEqual(fa[0], [1.0, 2.0, 3.0, 4.0, 5.0])
assert len(fa) == 1
assert fa[0] == [1.0, 2.0, 3.0, 4.0, 5.0]
fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5]
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]
class TestCase(unittest.TestCase):
class TestCase:
def test_init(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
ins = Instance(x=[1, 2, 3], y=[4, 5, 6])
self.assertTrue(isinstance(ins.fields, dict))
self.assertEqual(ins.fields, fields)
assert isinstance(ins.fields, dict) == True
assert ins.fields == fields
ins = Instance(**fields)
self.assertEqual(ins.fields, fields)
assert ins.fields == fields
def test_add_field(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
ins = Instance(**fields)
ins.add_field("z", [1, 1, 1])
fields.update({"z": [1, 1, 1]})
self.assertEqual(ins.fields, fields)
assert ins.fields == fields
def test_get_item(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
ins = Instance(**fields)
self.assertEqual(ins["x"], [1, 2, 3])
self.assertEqual(ins["y"], [4, 5, 6])
self.assertEqual(ins["z"], [1, 1, 1])
assert ins["x"] == [1, 2, 3]
assert ins["y"] == [4, 5, 6]
assert ins["z"] == [1, 1, 1]
def test_repr(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}

View File

@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers.reproducible_sampler import RandomSampler
from fastNLP.core.samplers import ReproducibleBatchSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from fastNLP.core import synchronize_safe_rm

View File

@ -30,7 +30,7 @@ class SequenceDataSet:
def check_replace_sampler(driver):
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSamplerReproducibleBatchSampler
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSamplerRandomBatchSampler
# reproducible 是 True 和 False
# 需要 check 返回的 sampler 和 dataloader 都不同了

View File

@ -118,7 +118,6 @@ class TestAccuracy:
def test_v1(self, is_ddp: bool, dataset: DataSet, metric_class: Type['Metric'],
metric_kwargs: Dict[str, Any]) -> None:
global pool
print(pool)
if is_ddp:
if sys.platform == "win32":
pytest.skip("DDP not supported on windows")

View File

@ -1,5 +1,5 @@
import pytest
import unittest
from collections import Counter
import os, sys
import copy
@ -14,6 +14,7 @@ from torch.multiprocessing import Pool, set_start_method
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.dataset import DataSet
set_start_method("spawn", force=True)
@ -45,7 +46,6 @@ def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
print(torch.cuda.device_count())
if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
@ -64,15 +64,15 @@ def find_free_network_port() -> int:
return port
@pytest.fixture(scope='class', autouse=True)
def pre_process():
global pool
pool = Pool(processes=NUM_PROCESSES)
master_port = find_free_network_port()
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
yield
pool.close()
pool.join()
# @pytest.fixture(scope='class', autouse=True)
# def pre_process():
# global pool
# pool = Pool(processes=NUM_PROCESSES)
# master_port = find_free_network_port()
# pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
# yield
# pool.close()
# pool.join()
def _test(local_rank: int,
@ -87,18 +87,19 @@ def _test(local_rank: int,
# dataset 也类似(每个进程有自己的一个)
dataset = copy.deepcopy(dataset)
metric.to(device)
print(os.environ.get("MASTER_PORT", "xx"))
# 把数据拆到每个 GPU 上,有点模仿 DistributedSampler 的感觉,但这里数据单位是一个 batch即每个 i 取了一个 batch 到自己的 GPU 上)
for i in range(local_rank, len(dataset), world_size):
pred, tg, seq_len = dataset[i]['pred'].to(device), dataset[i]['tg'].to(device), dataset[i]['seq_len']
print(tg, seq_len)
metric.update(pred, tg, seq_len)
my_result = metric.get_metric()
print(my_result)
print(sklearn_metric)
assert my_result == sklearn_metric
class SpanFPreRecMetricTest(unittest.TestCase):
global pool
class TestSpanFPreRecMetric:
def test_case1(self):
from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans
@ -160,13 +161,11 @@ class SpanFPreRecMetricTest(unittest.TestCase):
-0.5837, 1.0184],
[ 1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
-0.9025, 0.0864]]])
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4],
[4, 1, 7, 0, 4, 7]])
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]])
fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6])
expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5,
'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0,
'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2}
assert expect_bio_res == fastnlp_bio_metric.get_metric()
# print(fastnlp_bio_metric.get_metric())
@ -253,7 +252,7 @@ class SpanFPreRecMetricTest(unittest.TestCase):
# print(expected_metric)
metric_value = metric.get_metric()
for key, value in expected_metric.items():
self.assertAlmostEqual(value, metric_value[key], places=5)
np.allclose(value, metric_value[key])
def test_auto_encoding_type_infer(self):
# 检查是否可以自动check encode的类型
@ -270,7 +269,6 @@ class SpanFPreRecMetricTest(unittest.TestCase):
vocab.add_word('o')
vocabs[encoding_type] = vocab
for e in ['bio', 'bioes', 'bmeso']:
with self.subTest(e=e):
metric = SpanFPreRecMetric(tag_vocab=vocabs[e])
assert metric.encoding_type == e
@ -285,7 +283,7 @@ class SpanFPreRecMetricTest(unittest.TestCase):
vocab = Vocabulary()
for i in range(10):
vocab.add_word(str(i))
with self.assertRaises(Exception):
with pytest.raises(Exception):
metric = SpanFPreRecMetric(vocab)
def test_encoding_type(self):
@ -304,7 +302,6 @@ class SpanFPreRecMetricTest(unittest.TestCase):
vocab.add_word('o')
vocabs[encoding_type] = vocab
for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']):
with self.subTest(e1=e1, e2=e2):
if e1 == e2:
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2)
else:
@ -312,30 +309,30 @@ class SpanFPreRecMetricTest(unittest.TestCase):
s2.update(set(e1))
if s2 == set(e2):
continue
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2)
for encoding_type in ['bio', 'bioes', 'bmeso']:
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes')
with self.assertWarns(Warning):
with pytest.warns(Warning):
vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes'))
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso')
vocab = Vocabulary().add_word_lst(list('bmes'))
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso')
def test_case5(self):
global pool
# pool = Pool(NUM_PROCESSES)
# master_port = find_free_network_port()
# pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
# global pool
pool = Pool(NUM_PROCESSES)
master_port = find_free_network_port()
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
number_labels = 4
# bio tag
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels))
# fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
dataset = DataSet({'pred': [torch.FloatTensor(
[[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
dataset = DataSet({'pred': [
torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
-0.3782, 0.8240],
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
-0.3562, -1.4116],
@ -346,8 +343,10 @@ class SpanFPreRecMetricTest(unittest.TestCase):
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120,
-1.3508, -0.9513],
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
-0.0842, -0.4294]],
-0.0842, -0.4294]]
]),
torch.FloatTensor([
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
-1.4138, -0.8853],
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
@ -359,10 +358,16 @@ class SpanFPreRecMetricTest(unittest.TestCase):
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
-0.5837, 1.0184],
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
-0.9025, 0.0864]]])] * 100,
'tg': [torch.LongTensor([[3, 6, 0, 8, 2, 4],
[4, 1, 7, 0, 4, 7]])] * 100,
'seq_len': [[6, 6]] * 100})
-0.9025, 0.0864]]
])
],
'tg': [
torch.LongTensor([[3, 6, 0, 8, 2, 4]]),
torch.LongTensor([[4, 1, 7, 0, 4, 7]])
],
'seq_len': [
[6], [6]
]})
metric_kwargs = {
'tag_vocab': fastnlp_bio_vocab,
'only_gross': False,
@ -372,7 +377,6 @@ class SpanFPreRecMetricTest(unittest.TestCase):
'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0,
'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2}
processes = NUM_PROCESSES
print(torch.cuda.device_count())
pool.starmap(
partial(
@ -384,3 +388,5 @@ class SpanFPreRecMetricTest(unittest.TestCase):
),
[(rank, processes, torch.device(f'cuda:{rank}')) for rank in range(processes)]
)
pool.close()
pool.join()

View File

@ -4,7 +4,7 @@ import numpy as np
import pytest
from itertools import chain
from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
@ -18,7 +18,7 @@ class TestReproducibleBatchSampler:
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
forward_steps = 3
@ -28,15 +28,15 @@ class TestReproducibleBatchSampler:
# 1. 保存状态
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
assert isinstance(_get_re_batchsampler, RandomBatchSampler)
state = _get_re_batchsampler.state_dict()
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size,
"sampler_type": "ReproducibleBatchSampler"}
"sampler_type": "RandomBatchSampler"}
# 2. 断点重训,重新生成一个 dataloader
# 不改变 batch_size
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
@ -53,7 +53,7 @@ class TestReproducibleBatchSampler:
# 改变 batch_size
after_batch_size = 3
dataloader = DataLoader(dataset, batch_size=after_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
@ -99,7 +99,7 @@ class TestReproducibleBatchSampler:
dataset = TorchNormalDataset(num_of_data=100)
# 开启 shuffle来检验断点重训后的第二轮的 index list 是不是重新生成的;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
# 将一轮的所有数据保存下来,看是否恢复的是正确的;
@ -111,13 +111,13 @@ class TestReproducibleBatchSampler:
# 1. 保存状态
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
assert isinstance(_get_re_batchsampler, RandomBatchSampler)
state = _get_re_batchsampler.state_dict()
# 2. 断点重训,重新生成一个 dataloader
# 不改变 batch_size
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
@ -416,7 +416,6 @@ class TestBucketedBatchSampler:
@pytest.mark.parametrize('num_replica', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2):
# TODO 两个 rank 上的长度是要在同一个bucket的
dataset = DatasetWithVaryLength(num_of_data=num_samples)
batch_size = 6
if num_replica*batch_size > num_samples:

View File

@ -1,18 +1,14 @@
import unittest
from itertools import product
import numpy as np
import pytest
from functools import partial
from array import array
from itertools import chain
from fastNLP.core.samplers.reproducible_sampler import RandomSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
class TestRandomSamplerYh(unittest.TestCase):
class TestRandomSamplerYh:
def test_init(self):
# 测试能否正确初始化
dataset = TorchNormalDataset(num_of_data=100)
@ -24,7 +20,7 @@ class TestRandomSamplerYh(unittest.TestCase):
dataset = TorchNormalDataset(num_of_data=100)
sampler = RandomSampler(dataset)
for i in sampler:
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
sampler.set_distributed(1, 0)
break
@ -37,39 +33,39 @@ class TestRandomSamplerYh(unittest.TestCase):
dataset = TorchNormalDataset(num_of_data=100)
sampler = RandomSampler(dataset, shuffle=False)
sampler.set_distributed(num_replicas=2, rank=0, pad=False)
self.assertEqual(len(sampler), 50)
assert len(sampler)==50
count = 0
for i in sampler:
self.assertEqual(i%2, 0)
assert i%2==0
count += 1
self.assertEqual(count, 50)
assert count == 50
sampler.set_distributed(num_replicas=2, rank=1, pad=False)
self.assertEqual(len(sampler), 50)
assert len(sampler)==50
count = 0
for i in sampler:
self.assertEqual(i%2, 1)
assert i%2==1
count += 1
self.assertEqual(count, 50)
assert count==50
dataset = TorchNormalDataset(num_of_data=101)
sampler = RandomSampler(dataset, shuffle=False)
sampler.set_distributed(num_replicas=2, rank=0, pad=True)
self.assertEqual(len(sampler), 51)
assert len(sampler)==51
count = 0
for i in sampler:
self.assertEqual(i%2, 0)
assert i%2==0
count += 1
self.assertEqual(count, 51)
assert count == 51
sampler.set_distributed(num_replicas=2, rank=1, pad=True)
self.assertEqual(len(sampler), 51)
assert len(sampler) == 51
count = 0
for i in sampler:
if i!=0:
self.assertEqual(i%2, 1)
assert i%2==1
count += 1
self.assertEqual(count, 51)
assert count == 51
def test_state_dict_check_length(self):
dataset = TorchNormalDataset(num_of_data=100)
@ -77,7 +73,7 @@ class TestRandomSamplerYh(unittest.TestCase):
states = sampler.state_dict()
new_ds = TorchNormalDataset(num_of_data=10)
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
new_sampler = RandomSampler(new_ds)
new_sampler.load_state_dict(states)
@ -85,14 +81,14 @@ class TestRandomSamplerYh(unittest.TestCase):
new_sampler = RandomSampler(new_ds)
new_sampler.load_state_dict(states)
def test_state_dict(self):
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('pre_shuffle', [True, False])
@pytest.mark.parametrize('post_shuffle', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
num_samples = 100
dataset = TorchNormalDataset(num_of_data=num_samples)
# 测试使用 前后shuffle不一致的load操作
lst = [0]+np.random.randint(1, num_samples, size=3).tolist()
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False],
lst):
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples):
sampler = RandomSampler(dataset, shuffle=pre_shuffle)
sampler.set_epoch(0)
already_numbers = set()
@ -101,7 +97,7 @@ class TestRandomSamplerYh(unittest.TestCase):
already_numbers.add(j)
if i == num_consumed_samples:
break
self.assertEqual(len(already_numbers), num_consumed_samples)
assert len(already_numbers) == num_consumed_samples
states = sampler.state_dict()
@ -109,32 +105,36 @@ class TestRandomSamplerYh(unittest.TestCase):
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
self.assertNotIn(i, already_numbers)
assert i not in already_numbers
# 测试切换成多卡也没有问题
other_rank_number = set()
for rank in range(3):
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
new_sampler.load_state_dict(states)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
new_sampler.set_epoch(0)
count = 0
seen = 0
seen_in_other_rank = 0
for i in new_sampler:
self.assertNotIn(i, other_rank_number)
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
self.assertNotIn(i, already_numbers)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1 # 因为pad可能重复
def test_state_dict_2(self):
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('pre_shuffle', [True, False])
@pytest.mark.parametrize('post_shuffle', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
num_samples = 100
dataset = TorchNormalDataset(num_of_data=num_samples)
# 测试使用 前后shuffle不一致的load操作
lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist()
# lst = [30]
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False],
lst):
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples):
already_numbers = set()
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
sampler.set_distributed(num_replicas=2, rank=0)
@ -152,7 +152,7 @@ class TestRandomSamplerYh(unittest.TestCase):
already_numbers.add(j)
if i == num_consumed_samples:
break
self.assertEqual(len(already_numbers), num_consumed_samples*2)
assert len(already_numbers) == num_consumed_samples*2
states = sampler.state_dict()
@ -160,7 +160,7 @@ class TestRandomSamplerYh(unittest.TestCase):
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
self.assertNotIn(i, already_numbers)
assert i not in already_numbers
# 测试切换成多卡也没有问题
other_rank_number = set()
@ -168,16 +168,20 @@ class TestRandomSamplerYh(unittest.TestCase):
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
count = 0
seen = 0
seen_in_other_rank = 0
for i in new_sampler:
self.assertNotIn(i, other_rank_number)
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
self.assertNotIn(i, already_numbers)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1 # 因为pad可能重复
class TestRandomSampler(unittest.TestCase):
class TestRandomSampler:
# 测试单卡;
def test_seed_work_when_shuffle_is_true(self):
data_length = 100
@ -360,4 +364,324 @@ class TestRandomSampler(unittest.TestCase):
...
class DatasetWithVaryLength:
def __init__(self, num_of_data=100, reverse=False):
self.data = np.arange(num_of_data)
if reverse:
self.data = self.data[::-1]
def __getitem__(self, item):
return self.data[item]
def __len__(self):
return len(self.data)
class TestSortedSampler:
def test_single(self):
num_of_data = 100
data = DatasetWithVaryLength(num_of_data)
sampler = SortedSampler(data, length=data.data)
indexes = list(sampler)
assert indexes==list(range(num_of_data-1, -1, -1))
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, pad, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
samplers = []
for i in range(num_replica):
sampler = SortedSampler(dataset=data, length=data.data)
sampler.set_distributed(num_replica, rank=i, pad=pad)
samplers.append(sampler)
# 保证顺序是没乱的
already_seen_index = set()
for sampler in samplers:
larger_count = 0 # 这里为 0 就可以因为最后补充的index一定是比较大的数。
prev_index = float('inf')
cur_set = set()
seen_in_other_rank = 0
for index in sampler:
seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉
cur_set.add(index)
larger_count += int(index <= prev_index)
prev_index = index
assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序
assert seen_in_other_rank <= 1 if pad else seen_in_other_rank == 0
already_seen_index.update(cur_set)
indexes = list(chain(*samplers))
indexes = set(indexes)
if pad:
assert indexes == set(range(num_of_data))
else:
assert len(indexes) <= num_of_data
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
def test_state_dict(self, pad, num_consumed_samples):
num_samples = 100
dataset = DatasetWithVaryLength(num_of_data=num_samples)
# 测试使用 前后shuffle不一致的load操作
sampler = SortedSampler(dataset, length=dataset.data)
sampler.set_epoch(0)
already_numbers = set()
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
if already_numbers:
assert j<max(already_numbers)
already_numbers.add(j)
if i == num_consumed_samples:
break
assert len(already_numbers) == num_consumed_samples
states = sampler.state_dict()
new_sampler = SortedSampler(dataset, length=dataset.data)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
if already_numbers:
assert i < max(already_numbers)
assert i not in already_numbers
# 测试切换成多卡也没有问题
other_rank_number = set()
for rank in range(3):
new_sampler = SortedSampler(dataset, length=dataset.data)
new_sampler.load_state_dict(states)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
new_sampler.set_epoch(0)
count = 0
seen = 0
seen_in_other_rank = 0
smaller = 0
for i in new_sampler:
if already_numbers:
smaller += int(i >= max(already_numbers))
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1 # 因为pad可能重复
assert smaller<=1 if pad else smaller==0
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
def test_state_dict_2(self, pad, num_consumed_samples):
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
num_samples = 100
dataset = DatasetWithVaryLength(num_of_data=num_samples)
# 测试使用 前后shuffle不一致的load操作
# lst = [30]
already_numbers = set()
sampler = SortedSampler(dataset, length=dataset.data)
sampler.set_distributed(num_replicas=2, rank=0)
sampler.set_epoch(0)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
if already_numbers:
assert j<=max(already_numbers)
already_numbers.add(j)
if i == num_consumed_samples:
break
sampler = SortedSampler(dataset, length=dataset.data)
sampler.set_epoch(0)
sampler.set_distributed(num_replicas=2, rank=1)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
already_numbers.add(j)
if i == num_consumed_samples:
break
assert len(already_numbers) == num_consumed_samples*2
states = sampler.state_dict()
new_sampler = SortedSampler(dataset, length=dataset.data)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
if already_numbers:
assert i < max(already_numbers)
assert i not in already_numbers
# 测试切换成多卡也没有问题
other_rank_number = set()
for rank in range(3):
new_sampler = SortedSampler(dataset, length=dataset.data)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
count = 0
seen = 0
seen_in_other_rank = 0
smaller = 0
for i in new_sampler:
if already_numbers:
smaller += int(i>=max(already_numbers))
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1 # 因为pad可能重复
assert smaller <= 1 if pad else smaller == 0
class TestSequentialSampler:
def test_single(self):
num_of_data = 100
data = DatasetWithVaryLength(num_of_data)
sampler = SequentialSampler(data)
indexes = list(sampler)
assert indexes==list(range(num_of_data))
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, pad, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
samplers = []
for i in range(num_replica):
sampler = SequentialSampler(dataset=data)
sampler.set_distributed(num_replica, rank=i, pad=pad)
samplers.append(sampler)
# 保证顺序是没乱的
already_seen_index = set()
for idx, sampler in enumerate(samplers):
larger_count = 1
prev_index = float('inf')
cur_set = set()
seen_in_other_rank = 0
for index in sampler:
seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉
cur_set.add(index)
larger_count += int(index >= prev_index)
prev_index = index
assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序
assert seen_in_other_rank <= idx if pad else seen_in_other_rank == 0
already_seen_index.update(cur_set)
indexes = list(chain(*samplers))
indexes = set(indexes)
if pad:
assert indexes == set(range(num_of_data))
else:
assert len(indexes) <= num_of_data
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
def test_state_dict(self, pad, num_consumed_samples):
num_samples = 100
dataset = DatasetWithVaryLength(num_of_data=num_samples)
# 测试使用 前后shuffle不一致的load操作
sampler = SequentialSampler(dataset=dataset)
sampler.set_epoch(0)
already_numbers = set()
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
if already_numbers:
assert j>max(already_numbers)
already_numbers.add(j)
if i == num_consumed_samples:
break
assert len(already_numbers) == num_consumed_samples
states = sampler.state_dict()
new_sampler = SequentialSampler(dataset=dataset)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
if already_numbers:
assert i > max(already_numbers)
assert i not in already_numbers
# 测试切换成多卡也没有问题
other_rank_number = set()
for rank in range(3):
new_sampler = SequentialSampler(dataset=dataset)
new_sampler.load_state_dict(states)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
new_sampler.set_epoch(0)
count = 0
seen = 0
seen_in_other_rank = 0
smaller = 0
for i in new_sampler:
if already_numbers:
smaller += int(i <= max(already_numbers))
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=rank # 因为pad可能重复
assert smaller<=1 if pad else smaller==0
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
def test_state_dict_2(self, pad, num_consumed_samples):
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
num_samples = 100
dataset = DatasetWithVaryLength(num_of_data=num_samples)
# 测试使用 前后shuffle不一致的load操作
# lst = [30]
already_numbers = set()
sampler = SequentialSampler(dataset=dataset)
sampler.set_distributed(num_replicas=2, rank=0)
sampler.set_epoch(0)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
if already_numbers:
assert j>max(already_numbers)
already_numbers.add(j)
if i == num_consumed_samples:
break
sampler = SequentialSampler(dataset=dataset)
sampler.set_epoch(0)
sampler.set_distributed(num_replicas=2, rank=1)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
already_numbers.add(j)
if i == num_consumed_samples:
break
assert len(already_numbers) == num_consumed_samples*2
states = sampler.state_dict()
new_sampler = SequentialSampler(dataset=dataset)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
if already_numbers:
assert i > max(already_numbers)
assert i not in already_numbers
# 测试切换成多卡也没有问题
other_rank_number = set()
for rank in range(3):
new_sampler = SequentialSampler(dataset=dataset)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
count = 0
seen = 0
seen_in_other_rank = 0
smaller = 0
for i in new_sampler:
if already_numbers:
smaller += int(i<max(already_numbers))
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1 # 因为pad可能重复
assert smaller <= rank if pad else smaller == 0

View File

@ -0,0 +1,104 @@
from itertools import chain
import pytest
from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
class DatasetWithVaryLength:
def __init__(self, num_of_data=100):
self.data = list(range(num_of_data))
def __getitem__(self, item):
return self.data[item]
def __len__(self):
return len(self.data)
class TestUnrepeatedSampler:
@pytest.mark.parametrize('shuffle', [True, False])
def test_single(self, shuffle):
num_of_data = 100
data = DatasetWithVaryLength(num_of_data)
sampler = UnrepeatedRandomSampler(data, shuffle)
indexes = set(sampler)
assert indexes==set(range(num_of_data))
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
@pytest.mark.parametrize('shuffle', [False, True])
def test_multi(self, num_replica, num_of_data, shuffle):
data = DatasetWithVaryLength(num_of_data=num_of_data)
samplers = []
for i in range(num_replica):
sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle)
sampler.set_distributed(num_replica, rank=i)
samplers.append(sampler)
indexes = list(chain(*samplers))
assert len(indexes) == num_of_data
indexes = set(indexes)
assert indexes==set(range(num_of_data))
class TestUnrepeatedSortedSampler:
def test_single(self):
num_of_data = 100
data = DatasetWithVaryLength(num_of_data)
sampler = UnrepeatedSortedSampler(data, length=data.data)
indexes = list(sampler)
assert indexes==list(range(num_of_data-1, -1, -1))
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
samplers = []
for i in range(num_replica):
sampler = UnrepeatedSortedSampler(dataset=data, length=data.data)
sampler.set_distributed(num_replica, rank=i)
samplers.append(sampler)
# 保证顺序是没乱的
for sampler in samplers:
prev_index = float('inf')
for index in sampler:
assert index <= prev_index
prev_index = index
indexes = list(chain(*samplers))
assert len(indexes) == num_of_data # 不同卡之间没有交叉
indexes = set(indexes)
assert indexes==set(range(num_of_data))
class TestUnrepeatedSequentialSampler:
def test_single(self):
num_of_data = 100
data = DatasetWithVaryLength(num_of_data)
sampler = UnrepeatedSequentialSampler(data, length=data.data)
indexes = list(sampler)
assert indexes==list(range(num_of_data))
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
samplers = []
for i in range(num_replica):
sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data)
sampler.set_distributed(num_replica, rank=i)
samplers.append(sampler)
# 保证顺序是没乱的
for sampler in samplers:
prev_index = float('-inf')
for index in sampler:
assert index>=prev_index
prev_index = index
indexes = list(chain(*samplers))
assert len(indexes) == num_of_data
indexes = set(indexes)
assert indexes == set(range(num_of_data))