mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 12:17:35 +08:00
fix conflict
This commit is contained in:
commit
3a3c38a44e
131
README.md
131
README.md
@ -6,4 +6,133 @@
|
||||
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
|
||||
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)
|
||||
|
||||
dev0.8.0正在开发中
|
||||
fastNLP是一款轻量级的自然语言处理(NLP)工具包,目标是快速实现NLP任务以及构建复杂模型。
|
||||
|
||||
fastNLP具有如下的特性:
|
||||
|
||||
- 统一的Tabular式数据容器,简化数据预处理过程;
|
||||
- 内置多种数据集的Loader和Pipe,省去预处理代码;
|
||||
- 各种方便的NLP工具,例如Embedding加载(包括ELMo和BERT)、中间数据cache等;
|
||||
- 部分[数据集与预训练模型](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0)的自动下载;
|
||||
- 提供多种神经网络组件以及复现模型(涵盖中文分词、命名实体识别、句法分析、文本分类、文本匹配、指代消解、摘要等任务);
|
||||
- Trainer提供多种内置Callback函数,方便实验记录、异常捕获等。
|
||||
|
||||
## 安装指南
|
||||
|
||||
fastNLP 依赖以下包:
|
||||
|
||||
+ numpy>=1.14.2
|
||||
+ torch>=1.0.0
|
||||
+ tqdm>=4.28.1
|
||||
+ nltk>=3.4.1
|
||||
+ requests
|
||||
+ spacy
|
||||
+ prettytable>=0.7.2
|
||||
|
||||
其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 [PyTorch 官网](https://pytorch.org/) 。
|
||||
在依赖包安装完成后,您可以在命令行执行如下指令完成安装
|
||||
|
||||
```shell
|
||||
pip install fastNLP
|
||||
python -m spacy download en
|
||||
```
|
||||
|
||||
|
||||
## fastNLP教程
|
||||
中文[文档](https://fastnlp.readthedocs.io/)、[教程](https://fastnlp.readthedocs.io/zh/latest/user/tutorials.html)
|
||||
|
||||
### 快速入门
|
||||
|
||||
- [0. 快速入门](https://fastnlp.readthedocs.io/zh/latest/user/quickstart.html)
|
||||
|
||||
### 详细使用教程
|
||||
|
||||
- [1. 使用DataSet预处理文本](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_1_data_preprocess.html)
|
||||
- [2. 使用Vocabulary转换文本与index](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_vocabulary.html)
|
||||
- [3. 使用Embedding模块将文本转成向量](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html)
|
||||
- [4. 使用Loader和Pipe加载并处理数据集](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html)
|
||||
- [5. 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_5_loss_optimizer.html)
|
||||
- [6. 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_6_datasetiter.html)
|
||||
- [7. 使用Metric快速评测你的模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_7_metrics.html)
|
||||
- [8. 使用Modules和Models快速搭建自定义模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_8_modules_models.html)
|
||||
- [9. 快速实现序列标注模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_9_seq_labeling.html)
|
||||
- [10. 使用Callback自定义你的训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_10_callback.html)
|
||||
|
||||
### 扩展教程
|
||||
|
||||
- [Extend-1. BertEmbedding的各种用法](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_1_bert_embedding.html)
|
||||
- [Extend-2. 分布式训练简介](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_dist.html)
|
||||
- [Extend-3. 使用fitlog 辅助 fastNLP 进行科研](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_3_fitlog.html)
|
||||
|
||||
|
||||
## 内置组件
|
||||
|
||||
大部分用于的 NLP 任务神经网络都可以看做由词嵌入(embeddings)和两种模块:编码器(encoder)、解码器(decoder)组成。
|
||||
|
||||
以文本分类任务为例,下图展示了一个BiLSTM+Attention实现文本分类器的模型流程图:
|
||||
|
||||
|
||||
![](./docs/source/figures/text_classification.png)
|
||||
|
||||
fastNLP 在 embeddings 模块中内置了几种不同的embedding:静态embedding(GloVe、word2vec)、上下文相关embedding
|
||||
(ELMo、BERT)、字符embedding(基于CNN或者LSTM的CharEmbedding)
|
||||
|
||||
与此同时,fastNLP 在 modules 模块中内置了两种模块的诸多组件,可以帮助用户快速搭建自己所需的网络。 两种模块的功能和常见组件如下:
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><b> 类型 </b></td>
|
||||
<td><b> 功能 </b></td>
|
||||
<td><b> 例子 </b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td> encoder </td>
|
||||
<td> 将输入编码为具有具有表示能力的向量 </td>
|
||||
<td> Embedding, RNN, CNN, Transformer, ...
|
||||
</tr>
|
||||
<tr>
|
||||
<td> decoder </td>
|
||||
<td> 将具有某种表示意义的向量解码为需要的输出形式 </td>
|
||||
<td> MLP, CRF, ... </td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
## 项目结构
|
||||
|
||||
<div align=center><img width="450" height="350" src="./docs/source/figures/workflow.png"/></div>
|
||||
|
||||
|
||||
|
||||
fastNLP的大致工作流程如上图所示,而项目结构如下:
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><b> fastNLP </b></td>
|
||||
<td> 开源的自然语言处理库 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><b> fastNLP.core </b></td>
|
||||
<td> 实现了核心功能,包括数据处理组件、训练器、测试器等 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><b> fastNLP.models </b></td>
|
||||
<td> 实现了一些完整的神经网络模型 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><b> fastNLP.modules </b></td>
|
||||
<td> 实现了用于搭建神经网络模型的诸多组件 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><b> fastNLP.embeddings </b></td>
|
||||
<td> 实现了将序列index转为向量序列的功能,包括读取预训练embedding等 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><b> fastNLP.io </b></td>
|
||||
<td> 实现了读写功能,包括数据读入与预处理,模型读写,数据与模型自动下载等 </td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<hr>
|
||||
|
||||
*In memory of @FengZiYjun. May his soul rest in peace. We will miss you very very much!*
|
||||
|
@ -4,7 +4,8 @@ __all__ = [
|
||||
'EventsList',
|
||||
'Filter',
|
||||
'CallbackManager',
|
||||
'CheckpointCallback',
|
||||
'ModelCheckpointCallback',
|
||||
'TrainerCheckpointCallback',
|
||||
'choose_progress_callback',
|
||||
'ProgressCallback',
|
||||
'RichCallback',
|
||||
@ -16,7 +17,7 @@ __all__ = [
|
||||
from .callback import Callback
|
||||
from .callback_events import EventsList, Events, Filter
|
||||
from .callback_manager import CallbackManager
|
||||
from .checkpoint_callback import CheckpointCallback
|
||||
from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
|
||||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback
|
||||
from .lr_scheduler_callback import LRSchedCallback
|
||||
from .load_best_model_callback import LoadBestModelCallback
|
||||
|
@ -8,7 +8,7 @@ __all__ = [
|
||||
|
||||
from .callback_events import Events
|
||||
from .callback import Callback
|
||||
from .checkpoint_callback import CheckpointCallback
|
||||
from .checkpoint_callback import TrainerCheckpointCallback
|
||||
from .progress_callback import ProgressCallback, choose_progress_callback
|
||||
from fastNLP.core.log import logger
|
||||
|
||||
@ -98,7 +98,7 @@ class CallbackManager:
|
||||
:return:
|
||||
"""
|
||||
for each_callback in self.class_callbacks:
|
||||
if isinstance(each_callback, CheckpointCallback) and each_callback.is_trainer_checkpoint:
|
||||
if isinstance(each_callback, TrainerCheckpointCallback):
|
||||
self._has_trainer_checkpoint = True
|
||||
self.dissect_one_callback(each_callback)
|
||||
|
||||
@ -210,7 +210,7 @@ class CallbackManager:
|
||||
each_callback.on_load_checkpoint(trainer, None)
|
||||
|
||||
@property
|
||||
def has_trainer_chechpoint(self) -> bool:
|
||||
def has_trainer_checkpoint(self) -> bool:
|
||||
return self._has_trainer_checkpoint
|
||||
|
||||
@_transfer
|
||||
|
@ -1,12 +1,13 @@
|
||||
import os
|
||||
from typing import Union, Optional, Callable, Dict, Sequence
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from time import sleep
|
||||
|
||||
__all__ = [
|
||||
'CheckpointCallback'
|
||||
'ModelCheckpointCallback',
|
||||
'TrainerCheckpointCallback'
|
||||
]
|
||||
import os
|
||||
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping
|
||||
from pathlib import Path
|
||||
from abc import ABC
|
||||
import sys
|
||||
|
||||
|
||||
import fastNLP
|
||||
from .callback import Callback, Filter
|
||||
@ -14,35 +15,37 @@ from fastNLP.core.callbacks.utils import _get_monitor_value
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME
|
||||
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
|
||||
from fastNLP.core.utils import apply_to_collection
|
||||
|
||||
|
||||
class CanItemDataType(ABC):
|
||||
"""
|
||||
检测可以进行传输的对象。
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
|
||||
if cls is CanItemDataType:
|
||||
item = getattr(subclass, 'item', None)
|
||||
return callable(item)
|
||||
return NotImplemented
|
||||
|
||||
|
||||
|
||||
class CheckpointCallback(Callback):
|
||||
"""
|
||||
1. 因为只有 'Trainer' 才有 callback,因此评测 metric 实际上就是 validate 时干的事情;
|
||||
2. 默认 'save_last' 为 True,即 model_checkpoint 的默认逻辑是在每一个 epoch 下保存最后的一个模型,模型名字为 last.pth.tar;
|
||||
3. 理论上一个 model_checkpoint 的实例只会负责一个 monitor 的监视,如果用户在训练过程中指定了多个 monitor 的监视,例如 "acc1",
|
||||
"acc2", ... 那么我们会为用户创建多个 model_checkpoint 的实例;
|
||||
4. 理论上,在实际保存的过程中,topk 模式和 固定频率保存的模式是完全独立的,我们确实应当采取一些措施至少保证两者的名字不一样;
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
monitor,
|
||||
is_trainer_checkpoint: Optional[bool] = False,
|
||||
|
||||
save_folder: Optional[Union[str, Path]] = None,
|
||||
|
||||
save_every_n_epochs: Optional[int] = None,
|
||||
save_every_n_global_batches: Optional[int] = None,
|
||||
save_every_n_batches: Optional[int] = None,
|
||||
save_last: bool = True,
|
||||
save_topk: Optional[int] = None,
|
||||
save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None,
|
||||
|
||||
larger_better: bool = True,
|
||||
only_state_dict: bool = True,
|
||||
|
||||
model_save_fn: Optional[Callable] = None,
|
||||
|
||||
**kwargs,
|
||||
):
|
||||
if monitor is None and save_topk is not None:
|
||||
@ -51,9 +54,6 @@ class CheckpointCallback(Callback):
|
||||
if monitor is not None and not isinstance(monitor, str):
|
||||
raise ValueError("Parameter `monitor` should be of 'str' type.")
|
||||
|
||||
if not isinstance(is_trainer_checkpoint, bool):
|
||||
raise TypeError("Parameter 'is_trainer_checkpoint' can only be `bool` type.")
|
||||
|
||||
if save_folder is None:
|
||||
logger.warning(
|
||||
"Parameter `path` is None, and we will use the current work directory to find and load your model.")
|
||||
@ -67,15 +67,15 @@ class CheckpointCallback(Callback):
|
||||
if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1:
|
||||
raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.")
|
||||
|
||||
# 突然发现有一个骚操作在于 'Filter' 内部记载的状态值例如 'num_called' 是这个类全局的,而每次调用 __call__ 中输入的
|
||||
# 函数却是及时传入的,也就是说,我们可以保证 'Filter' 的正常控制频率的逻辑,然后每一次运行的函数都不一样;
|
||||
self._filter_every_n_epochs = Filter(every=save_every_n_epochs)
|
||||
else:
|
||||
save_every_n_epochs = sys.maxsize # 使得没有数字可以整除
|
||||
|
||||
if save_every_n_global_batches is not None:
|
||||
if not isinstance(save_every_n_global_batches, int) or save_every_n_global_batches < 1:
|
||||
if save_every_n_batches is not None:
|
||||
if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1:
|
||||
raise ValueError(
|
||||
"parameter save_every_n_global_batches should be an int and greater than or equal to 1.")
|
||||
self._filter_every_n_global_batches = Filter(every=save_every_n_global_batches)
|
||||
"parameter save_every_n_batches should be an int and greater than or equal to 1.")
|
||||
else:
|
||||
save_every_n_batches = sys.maxsize # 使得没有数字可以整除
|
||||
|
||||
if save_topk is not None:
|
||||
if not isinstance(save_topk, int) or save_topk < 1:
|
||||
@ -89,12 +89,12 @@ class CheckpointCallback(Callback):
|
||||
if not issubclass(exception, BaseException):
|
||||
raise TypeError("Each exception in parameter `save_on_exception` can only be "
|
||||
"`BaseException` type.")
|
||||
|
||||
else:
|
||||
save_on_exception = []
|
||||
self.monitor = monitor
|
||||
self.is_trainer_checkpoint = is_trainer_checkpoint
|
||||
self.save_folder = Path(save_folder)
|
||||
self.save_every_n_epochs = save_every_n_epochs
|
||||
self.save_every_n_global_batches = save_every_n_global_batches
|
||||
self.save_every_n_batches = save_every_n_batches
|
||||
self.save_last = save_last
|
||||
self.save_topk = save_topk
|
||||
self.larger_better = larger_better
|
||||
@ -107,7 +107,7 @@ class CheckpointCallback(Callback):
|
||||
self._topk_model = {}
|
||||
self._topn = 0 # 表示目前已经保存了几个最好的模型;
|
||||
|
||||
# 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用模糊匹配找到的第一个
|
||||
# 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的
|
||||
# key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次
|
||||
# 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为;
|
||||
# 因此我们通过该变量来表示我们通过模糊匹配拿到的 key;
|
||||
@ -115,76 +115,83 @@ class CheckpointCallback(Callback):
|
||||
|
||||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候,
|
||||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的;
|
||||
self.log_filepath = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
|
||||
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
|
||||
# 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行;
|
||||
synchronize_mkdir(self.log_filepath)
|
||||
synchronize_mkdir(self.timestamp_path)
|
||||
|
||||
def on_validate_end(self, trainer, validate_res):
|
||||
self._save_topk(trainer, validate_res)
|
||||
|
||||
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"):
|
||||
self._save_every_n_epochs(trainer)
|
||||
self._save_last(trainer)
|
||||
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0:
|
||||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}'
|
||||
self.save(trainer, folder_name=folder_name)
|
||||
if self.save_last:
|
||||
folder_name = f'{self.folder_prefix}-last'
|
||||
self.save(trainer, folder_name=folder_name)
|
||||
|
||||
def on_train_batch_end(self, trainer):
|
||||
self._save_every_n_global_batches(trainer)
|
||||
if trainer.global_forward_batches % self.save_every_n_batches == 0:
|
||||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}'
|
||||
self.save(trainer, folder_name=folder_name)
|
||||
|
||||
def on_exception(self, trainer, exception: BaseException):
|
||||
if self.save_on_exception is not None and exception.__class__ in self.save_on_exception:
|
||||
folder = self._get_checkpoint_real_save_folder(trainer=trainer, topk=False, metric=None)
|
||||
folder = folder + f"_{exception.__class__.__name__}"
|
||||
self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder=folder)
|
||||
if exception.__class__ in self.save_on_exception:
|
||||
folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \
|
||||
f'exception_{exception.__class__.__name__}'
|
||||
self.save(trainer=trainer, folder_name=folder_name)
|
||||
|
||||
def on_sanity_check_end(self, trainer, sanity_check_res):
|
||||
# 主要核对一下 monitor 是否存在。
|
||||
self._get_validate_metric(sanity_check_res)
|
||||
|
||||
def on_save_checkpoint(self, trainer) -> Dict:
|
||||
"""
|
||||
我们需要保存 CheckpointCallback 内部的几个 filter 的状态;
|
||||
保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。
|
||||
topk_model的状态
|
||||
_real_monitor的值
|
||||
"""
|
||||
|
||||
states = {}
|
||||
if self.save_every_n_epochs is not None:
|
||||
states["_filter_every_n_epochs"] = self._filter_every_n_epochs.state_dict()
|
||||
if self.save_every_n_global_batches is not None:
|
||||
states["_filter_every_n_global_batches"] = self._filter_every_n_global_batches.state_dict()
|
||||
states["real_monitor"] = self._real_monitor
|
||||
states['timestamp_path'] = str(self.timestamp_path.absolute())
|
||||
states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType,
|
||||
function=lambda x:x.item())
|
||||
states['save_topk'] = 0 if self.save_topk is None else self.save_topk
|
||||
states['_real_monitor'] = self._real_monitor
|
||||
return states
|
||||
|
||||
def on_load_checkpoint(self, trainer, states: Optional[Dict]):
|
||||
if self.save_every_n_epochs is not None:
|
||||
self._filter_every_n_epochs.load_state_dict(states["_filter_every_n_epochs"])
|
||||
if self.save_every_n_global_batches is not None:
|
||||
self._filter_every_n_global_batches.load_state_dict(states["_filter_every_n_global_batches"])
|
||||
timestamp_path = states['timestamp_path']
|
||||
if not os.path.exists(timestamp_path):
|
||||
logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to "
|
||||
f" {self.timestamp_path.absolute()}.")
|
||||
else:
|
||||
logger.info(f"Resume to save in path: {timestamp_path}.")
|
||||
self.timestamp_path = Path(timestamp_path)
|
||||
_topk_model = states['_topk_model']
|
||||
save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk'])
|
||||
if save_topk is not None and self.save_topk is not None:
|
||||
assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \
|
||||
f"as {save_topk}."
|
||||
self._topk_model.update(self._topk_model)
|
||||
self._real_monitor = states["real_monitor"]
|
||||
|
||||
def _save_every_n_epochs(self, trainer: "fastNLP.Trainer"):
|
||||
if self.save_every_n_epochs is not None:
|
||||
if self.is_trainer_checkpoint:
|
||||
_fn_every_n_epochs = trainer.save
|
||||
else:
|
||||
_fn_every_n_epochs = trainer.save_model
|
||||
_fn_every_n_epochs = partial(self._save_fn, trainer, False, None, _fn_every_n_epochs, None)
|
||||
_fn_every_n_epochs = self._filter_every_n_epochs(_fn_every_n_epochs)
|
||||
_fn_every_n_epochs()
|
||||
|
||||
def _save_every_n_global_batches(self, trainer: "fastNLP.Trainer"):
|
||||
if self.save_every_n_global_batches is not None:
|
||||
if self.is_trainer_checkpoint:
|
||||
_fn_every_n_global_batches = trainer.save
|
||||
else:
|
||||
_fn_every_n_global_batches = trainer.save_model
|
||||
_fn_every_n_global_batches = partial(self._save_fn, trainer, False, None, _fn_every_n_global_batches, None)
|
||||
_fn_every_n_global_batches = self._filter_every_n_global_batches(_fn_every_n_global_batches)
|
||||
_fn_every_n_global_batches()
|
||||
|
||||
def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict):
|
||||
"""
|
||||
根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。
|
||||
|
||||
:param trainer:
|
||||
:param validate_res:
|
||||
:return:
|
||||
"""
|
||||
if self.save_topk is not None:
|
||||
_metric_value = self._get_validate_metric(validate_res)
|
||||
_saved_name = self._get_checkpoint_real_save_folder(trainer=trainer, topk=True, metric=_metric_value)
|
||||
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
|
||||
f"-{self._real_monitor}_{_metric_value}"
|
||||
|
||||
_should_save = False
|
||||
if self._topn < self.save_topk:
|
||||
self._topk_model[_saved_name] = _metric_value
|
||||
self._topk_model[folder_name] = _metric_value
|
||||
self._topn += 1
|
||||
_should_save = True
|
||||
else:
|
||||
@ -192,39 +199,27 @@ class CheckpointCallback(Callback):
|
||||
key=lambda x: self._topk_model[x])
|
||||
if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \
|
||||
(self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]):
|
||||
self._topk_model[_saved_name] = _metric_value
|
||||
self._topk_model[folder_name] = _metric_value
|
||||
_should_save = True
|
||||
self._topk_model.pop(_least_valuable_model)
|
||||
synchronize_safe_rm(self.log_filepath.joinpath(_least_valuable_model))
|
||||
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model))
|
||||
|
||||
assert len(self._topk_model) == self.save_topk == self._topn
|
||||
|
||||
if _should_save:
|
||||
self._save_fn(trainer=trainer, topk=True, metric=_metric_value, substitute_folder=_saved_name)
|
||||
self.save(trainer, folder_name=folder_name)
|
||||
|
||||
def _save_last(self, trainer: "fastNLP.Trainer"):
|
||||
if self.save_last:
|
||||
self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder="last")
|
||||
|
||||
def _save_fn(self, trainer, topk: bool = False, metric: Optional[Union[int, float]] = None,
|
||||
substitute_fn: Optional[Callable] = None, substitute_folder: Optional[str] = None):
|
||||
# 首先根据当前的 epoch 和 batch 在 parent_path/FASTNLP_LAUNCH_TIME 下创建子文件夹 epoch-batch-monitor 或者
|
||||
# epoch-batch-monitor-monitor_value;
|
||||
if substitute_folder is None:
|
||||
folder = self.log_filepath.joinpath(self._get_checkpoint_real_save_folder(trainer, topk, metric))
|
||||
else:
|
||||
folder = self.log_filepath.joinpath(substitute_folder)
|
||||
def save(self, trainer, folder_name):
|
||||
"""
|
||||
执行保存的函数,将数据保存在 save_folder/timestamp/folder_name 下。
|
||||
|
||||
:param trainer:
|
||||
:param folder_name:
|
||||
:return:
|
||||
"""
|
||||
folder = self.timestamp_path.joinpath(folder_name)
|
||||
synchronize_mkdir(folder)
|
||||
|
||||
# 然后再调用 trainer 的 save_model(用于保存模型)或者 save(用于断点重训)函数;
|
||||
if substitute_fn is not None:
|
||||
_fn = substitute_fn
|
||||
else:
|
||||
if self.is_trainer_checkpoint:
|
||||
_fn = trainer.save
|
||||
else:
|
||||
_fn = trainer.save_model
|
||||
_fn = getattr(trainer, self.save_fn_name)
|
||||
_fn(
|
||||
folder=folder,
|
||||
only_state_dict=self.only_state_dict,
|
||||
@ -243,18 +238,48 @@ class CheckpointCallback(Callback):
|
||||
self._real_monitor = use_monitor
|
||||
return value
|
||||
|
||||
def _get_checkpoint_real_save_folder(self, trainer: "fastNLP.Trainer", topk: bool = False,
|
||||
metric: Optional[Union[int, float]] = None) -> str:
|
||||
"""
|
||||
获取当前保存模型的真正地名字;
|
||||
metric 参数仅当 mode 为 'topk' 时起作用;
|
||||
"""
|
||||
cur_epoch_idx = trainer.cur_epoch_idx
|
||||
global_forward_batches = trainer.global_forward_batches
|
||||
_other = ""
|
||||
if topk:
|
||||
_other = f"_{metric}"
|
||||
return f"epoch_{cur_epoch_idx}-global_batch_{global_forward_batches}-{self._real_monitor}{_other}"
|
||||
@property
|
||||
def folder_prefix(self):
|
||||
raise NotImplementedError("The `folder_prefix` is not specified")
|
||||
|
||||
@property
|
||||
def save_fn_name(self):
|
||||
raise NotImplementedError("The `save_fn_name` is not specified.")
|
||||
|
||||
|
||||
class ModelCheckpointCallback(CheckpointCallback):
|
||||
"""
|
||||
保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下
|
||||
|
||||
- save_folder/
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- model-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型
|
||||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型
|
||||
- model-last/ # 最后一个 epoch 的保存
|
||||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。
|
||||
- model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名
|
||||
|
||||
model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。
|
||||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。
|
||||
|
||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
|
||||
的那个作为 monitor 。
|
||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
|
||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。
|
||||
:param save_every_n_epochs: 多少个 epoch 保存一次。
|
||||
:param save_every_n_batches: 多少个 batch 保存一次。
|
||||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。
|
||||
:param save_topk: 保存 monitor 结果 topK 个。
|
||||
:param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。
|
||||
:param larger_better: monitor 的值是否时越大越好。
|
||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。
|
||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。
|
||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。
|
||||
:param kwargs:
|
||||
"""
|
||||
@property
|
||||
def save_fn_name(self):
|
||||
return 'save_model'
|
||||
|
||||
@property
|
||||
def callback_name(self):
|
||||
@ -262,6 +287,55 @@ class CheckpointCallback(Callback):
|
||||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态;
|
||||
:return:
|
||||
"""
|
||||
return f"monitor-{self.monitor}#trainer_checkpoint-{self.is_trainer_checkpoint}#only_state_dict-{self.only_state_dict}"
|
||||
return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
|
||||
|
||||
@property
|
||||
def folder_prefix(self):
|
||||
return 'model'
|
||||
|
||||
|
||||
class TrainerCheckpointCallback(CheckpointCallback):
|
||||
"""
|
||||
保存 Trainer checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下
|
||||
|
||||
- save_folder/
|
||||
- YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的
|
||||
- trainer-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型
|
||||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型
|
||||
- trainer-last/ # 最后一个 epoch 的保存
|
||||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。
|
||||
- trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名
|
||||
|
||||
model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。
|
||||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。
|
||||
|
||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
|
||||
的那个作为 monitor 。
|
||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的
|
||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。
|
||||
:param save_every_n_epochs: 多少个 epoch 保存一次。
|
||||
:param save_every_n_batches: 多少个 batch 保存一次。
|
||||
:param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。
|
||||
:param save_topk: 保存 monitor 结果 topK 个。
|
||||
:param save_on_exception: 在出异常信息时,是否保存。
|
||||
:param larger_better: monitor 的值是否时越大越好。
|
||||
:param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无意义。
|
||||
:param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。
|
||||
如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。
|
||||
:param kwargs:
|
||||
"""
|
||||
@property
|
||||
def save_fn_name(self):
|
||||
return 'save'
|
||||
|
||||
@property
|
||||
def callback_name(self):
|
||||
"""
|
||||
通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态;
|
||||
:return:
|
||||
"""
|
||||
return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}"
|
||||
|
||||
@property
|
||||
def folder_prefix(self):
|
||||
return 'trainer'
|
||||
|
@ -31,7 +31,7 @@ class LoadBestModelCallback(Callback):
|
||||
请在函数内完成对模型的保存。
|
||||
:param model_load_fn: 加载 model 的函数,与 model_save_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出,
|
||||
请在函数内完成对模型的加载。
|
||||
:param delete_after_train: 在加载了最佳模型之后是否删掉模型。
|
||||
:param delete_after_train: 在训练结束后是否删掉模型。
|
||||
"""
|
||||
if model_load_fn is not None:
|
||||
assert callable(model_load_fn), "`model_load_fn` must be a callable object."
|
||||
|
@ -133,17 +133,18 @@ class Evaluator:
|
||||
|
||||
self.driver.barrier()
|
||||
|
||||
def run(self, num_eval_batch_per_dl: int = -1) -> Dict:
|
||||
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
|
||||
"""
|
||||
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。
|
||||
如果存在多个metric,一个dataloader的情况,key的命名规则是
|
||||
metric_indicator_name#metric_name
|
||||
如果存在多个数据集,一个metric的情况,key的命名规则是
|
||||
metric_indicator_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。
|
||||
如果存在多个metric,多个dataloader的情况,key的命名规则是
|
||||
metric_indicator_name#metric_name#dataloader_name
|
||||
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。
|
||||
如果存在多个metric,一个dataloader的情况,key的命名规则是
|
||||
metric_indicator_name#metric_name
|
||||
如果存在多个数据集,一个metric的情况,key的命名规则是
|
||||
metric_indicator_name#metric_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。
|
||||
如果存在多个metric,多个dataloader的情况,key的命名规则是
|
||||
metric_indicator_name#metric_name#dataloader_name
|
||||
其中 metric_indicator_name 可能不存在。
|
||||
|
||||
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。
|
||||
:return:
|
||||
"""
|
||||
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
|
||||
@ -157,7 +158,6 @@ class Evaluator:
|
||||
assert self.driver.has_test_dataloaders()
|
||||
|
||||
metric_results = {}
|
||||
|
||||
self.reset()
|
||||
evaluate_context = self.driver.get_evaluate_context()
|
||||
self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train')
|
||||
|
@ -23,7 +23,7 @@ from fastNLP.core.drivers import Driver
|
||||
from fastNLP.core.drivers.utils import choose_driver
|
||||
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
|
||||
from fastNLP.envs import rank_zero_call
|
||||
from fastNLP.core.samplers import ReproducibleIterator, ReproducibleBatchSampler
|
||||
from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME
|
||||
|
||||
@ -251,7 +251,7 @@ class Trainer(TrainerEventTrigger):
|
||||
self.driver.set_deterministic_dataloader(self.dataloader)
|
||||
|
||||
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
|
||||
reproducible=self.callback_manager.has_trainer_chechpoint)
|
||||
reproducible=self.callback_manager.has_trainer_checkpoint)
|
||||
|
||||
self.set_grad_to_none = kwargs.get("set_grad_to_none", True)
|
||||
self.on_after_trainer_initialized(self.driver)
|
||||
@ -291,6 +291,7 @@ class Trainer(TrainerEventTrigger):
|
||||
raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.")
|
||||
|
||||
if self.evaluator is not None and num_eval_sanity_batch > 0:
|
||||
logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.")
|
||||
self.on_sanity_check_begin()
|
||||
sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch)
|
||||
self.on_sanity_check_end(sanity_check_res)
|
||||
@ -509,7 +510,7 @@ class Trainer(TrainerEventTrigger):
|
||||
|
||||
:param folder: 保存模型的地址;
|
||||
:param only_state_dict: 是否只保存模型的 `state_dict`;
|
||||
:param save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数;
|
||||
:param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数;
|
||||
:param kwargs: 一些 driver 的保存模型的函数的参数另有其它;
|
||||
"""
|
||||
|
||||
@ -534,7 +535,16 @@ class Trainer(TrainerEventTrigger):
|
||||
|
||||
def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False,
|
||||
model_load_fn: Optional[Callable] = None, **kwargs):
|
||||
"""
|
||||
加载模型
|
||||
|
||||
:param folder: 读取 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时,
|
||||
直接将该 folder 传递到 model_load_fn 中。
|
||||
:param only_state_dict: 要读取的文件中是否仅包含模型权重。在 model_load_fn 不为 None 时,该参数无意义。
|
||||
:param model_load_fn: callable 的函数,接受一个 folder 作为参数,不返回任何内容。
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
self.on_load_model()
|
||||
self.driver.barrier()
|
||||
if not isinstance(folder, (io.BytesIO, BinaryIO)):
|
||||
@ -555,7 +565,13 @@ class Trainer(TrainerEventTrigger):
|
||||
|
||||
def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs):
|
||||
r"""
|
||||
用于断点重训的保存函数;
|
||||
用于断点重训 Trainer 的保存函数;
|
||||
|
||||
:param folder:
|
||||
:param only_state_dict:
|
||||
:param model_save_fn:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
self.driver.barrier()
|
||||
|
||||
@ -594,7 +610,7 @@ class Trainer(TrainerEventTrigger):
|
||||
r"""
|
||||
用于断点重训的加载函数;
|
||||
注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的
|
||||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleIterator;
|
||||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler;
|
||||
|
||||
注意我们目前不支持单卡到多卡的断点重训;
|
||||
|
||||
|
@ -24,6 +24,7 @@ class _FDataSet:
|
||||
对Dataset的封装,主要是修改dataset的__getitem__函数,增加返回下标idx,值得注意的是dataset需要实现__getattribute__函数才能在_FDataset
|
||||
中调用dataset的方法
|
||||
"""
|
||||
|
||||
def __init__(self, dataset) -> None:
|
||||
self.dataset = dataset
|
||||
|
||||
@ -45,6 +46,7 @@ class TorchDataLoader(DataLoader):
|
||||
提供给使用pytorch框架的DataLoader函数,若是配套使用FastNLP的dataset则可以自动使用AutoCollate函数对数据进行自动padding操作,用户也可以通过
|
||||
提供的方法调节设置collate_fn的若干参数。
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, batch_size: int = 1,
|
||||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
|
||||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
|
||||
@ -175,17 +177,17 @@ class TorchDataLoader(DataLoader):
|
||||
|
||||
|
||||
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
|
||||
batch_size: int = 1,
|
||||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
|
||||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
|
||||
num_workers: int = 0, collate_fn: Optional[Callable] = None,
|
||||
pin_memory: bool = False, drop_last: bool = False,
|
||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
|
||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
|
||||
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
|
||||
non_train_batch_size: int = 16, as_numpy: bool = False,
|
||||
input_fields: Union[List, str] = None)\
|
||||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
|
||||
batch_size: int = 1,
|
||||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
|
||||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
|
||||
num_workers: int = 0, collate_fn: Optional[Callable] = None,
|
||||
pin_memory: bool = False, drop_last: bool = False,
|
||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
|
||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
|
||||
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
|
||||
non_train_batch_size: int = 16, as_numpy: bool = False,
|
||||
input_fields: Union[List, str, None] = None) \
|
||||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
|
||||
"""
|
||||
传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象
|
||||
|
||||
@ -221,7 +223,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
|
||||
multiprocessing_context=multiprocessing_context, generator=generator,
|
||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
|
||||
as_numpy=as_numpy)
|
||||
dl.set_input(*input_fields)
|
||||
if input_fields:
|
||||
dl.set_input(*input_fields)
|
||||
return dl
|
||||
|
||||
elif isinstance(ds_or_db, DataBundle):
|
||||
@ -233,17 +236,21 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
|
||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
|
||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
|
||||
multiprocessing_context=multiprocessing_context, generator=generator,
|
||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
|
||||
prefetch_factor=prefetch_factor,
|
||||
persistent_workers=persistent_workers,
|
||||
as_numpy=as_numpy)
|
||||
else:
|
||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
|
||||
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
|
||||
shuffle=shuffle, sampler=non_train_sampler,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
|
||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
|
||||
multiprocessing_context=multiprocessing_context, generator=generator,
|
||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
|
||||
prefetch_factor=prefetch_factor,
|
||||
persistent_workers=persistent_workers,
|
||||
as_numpy=as_numpy)
|
||||
dl_bundle[name].set_input(*input_fields)
|
||||
if input_fields:
|
||||
dl_bundle[name].set_input(*input_fields)
|
||||
return dl_bundle
|
||||
|
||||
elif isinstance(ds_or_db, Sequence):
|
||||
@ -269,8 +276,9 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
|
||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
|
||||
as_numpy=as_numpy)
|
||||
)
|
||||
for dl in dl_bundle:
|
||||
dl.set_input(*input_fields)
|
||||
if input_fields:
|
||||
for dl in dl_bundle:
|
||||
dl.set_input(*input_fields)
|
||||
return dl_bundle
|
||||
|
||||
elif isinstance(ds_or_db, Mapping):
|
||||
@ -282,18 +290,22 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
|
||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
|
||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
|
||||
multiprocessing_context=multiprocessing_context, generator=generator,
|
||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
|
||||
prefetch_factor=prefetch_factor,
|
||||
persistent_workers=persistent_workers,
|
||||
as_numpy=as_numpy)
|
||||
else:
|
||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
|
||||
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
|
||||
shuffle=shuffle, sampler=non_train_sampler,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
|
||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
|
||||
multiprocessing_context=multiprocessing_context, generator=generator,
|
||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
|
||||
prefetch_factor=prefetch_factor,
|
||||
persistent_workers=persistent_workers,
|
||||
as_numpy=as_numpy)
|
||||
|
||||
dl_bundle[name].set_input(*input_fields)
|
||||
if input_fields:
|
||||
dl_bundle[name].set_input(*input_fields)
|
||||
|
||||
return dl_bundle
|
||||
else:
|
||||
|
@ -8,9 +8,8 @@ __all__ = [
|
||||
|
||||
import _pickle as pickle
|
||||
from copy import deepcopy
|
||||
from typing import Optional, List, Callable, Union, Dict, Any
|
||||
from typing import Optional, List, Callable, Union, Dict, Any, Mapping
|
||||
from functools import partial
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from threading import Thread
|
||||
@ -197,6 +196,20 @@ class DataSet:
|
||||
else:
|
||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
assert isinstance(key, int) and key<len(self)
|
||||
assert isinstance(value, Instance) or isinstance(value, Mapping)
|
||||
ins_keys = set(value.keys())
|
||||
ds_keys = set(self.get_field_names())
|
||||
|
||||
if len(ins_keys - ds_keys) != 0:
|
||||
raise KeyError(f"The following keys are not found in the Dataset:{list(ins_keys - ds_keys)}.")
|
||||
if len(ds_keys - ins_keys) != 0:
|
||||
raise KeyError(f"The following keys are not found in the Instance:{list(ds_keys - ins_keys)}.")
|
||||
|
||||
for field_name, field in self.field_arrays.items():
|
||||
field[key] = value[field_name]
|
||||
|
||||
def __getattribute__(self, item):
|
||||
return object.__getattribute__(self, item)
|
||||
|
||||
@ -813,6 +826,3 @@ class DataSet:
|
||||
self.collate_fns.set_input(*field_names)
|
||||
|
||||
|
||||
class IterableDataset:
|
||||
pass
|
||||
|
||||
|
@ -46,9 +46,6 @@ class FieldArray:
|
||||
|
||||
def __setitem__(self, idx: int, val: Any):
|
||||
assert isinstance(idx, int)
|
||||
if idx == -1:
|
||||
idx = len(self) - 1
|
||||
assert 0 <= idx < len(self), f"0<= idx <{len(self)}, but idx is {idx}"
|
||||
self.content[idx] = val
|
||||
|
||||
def get(self, indices: Union[int, List[int]]):
|
||||
@ -79,7 +76,7 @@ class FieldArray:
|
||||
|
||||
def split(self, sep: str = None, inplace: bool = True):
|
||||
r"""
|
||||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值
|
||||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。
|
||||
|
||||
:param sep: 分割符,如果为None则直接调用str.split()。
|
||||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。
|
||||
|
@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from io import BytesIO
|
||||
import json
|
||||
|
||||
__all__ = [
|
||||
'Driver'
|
||||
@ -48,10 +49,13 @@ class Driver(ABC):
|
||||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的
|
||||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist";
|
||||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None;
|
||||
注意当 dist 为 ReproducibleIterator, RandomBatchSampler 时,是断点重训加载时 driver.load 函数在调用;
|
||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数;
|
||||
|
||||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得
|
||||
可以可以加载。
|
||||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外,
|
||||
如果传入的 dataloader 中是 ReproducibleIterator 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的
|
||||
如果传入的 dataloader 中是 ReproducibleSampler 或者 RandomBatchSampler 需要重新初始化一个放入返回的
|
||||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。
|
||||
"""
|
||||
if dist is None and reproducible is False:
|
||||
@ -68,9 +72,12 @@ class Driver(ABC):
|
||||
def set_sampler_epoch(self, dataloader, cur_epoch_idx):
|
||||
r"""
|
||||
对于分布式的 sampler,例如 torch 的 DistributedSampler,其需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的;
|
||||
dataloader 中可能真正发挥作用的是 batch_sampler 也可能是 sampler。
|
||||
|
||||
:param dataloader: 需要设置 epoch 的 dataloader 。
|
||||
:param cur_epoch_idx: 当前是第几个 epoch;
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def train_step(self, batch):
|
||||
"""
|
||||
@ -444,13 +451,14 @@ class Driver(ABC):
|
||||
|
||||
exc_type, exc_value, exc_traceback_obj = sys.exc_info()
|
||||
_write_exc_info = {
|
||||
'exc_type': exc_type,
|
||||
'exc_value': exc_value,
|
||||
'time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')),
|
||||
'global_rank': getattr(self, "global_rank", None),
|
||||
'rank': self.get_local_rank(),
|
||||
'exc_type': str(exc_type.__name__),
|
||||
'exc_value': str(exc_value),
|
||||
'exc_time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')),
|
||||
'exc_global_rank': getattr(self, "global_rank", None),
|
||||
'exc_local_rank': self.get_local_rank(),
|
||||
}
|
||||
sys.stderr.write(str(_write_exc_info)+"\n")
|
||||
sys.stderr.write("\nException info:\n")
|
||||
sys.stderr.write(json.dumps(_write_exc_info, indent=2)+"\n")
|
||||
|
||||
sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n")
|
||||
for pid in self._pids:
|
||||
|
@ -3,7 +3,7 @@ from typing import Optional, Union
|
||||
|
||||
from .jittor_driver import JittorDriver
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
from fastNLP.core.samplers import ReproducibleIterator
|
||||
from fastNLP.core.samplers import ReproducibleSampler
|
||||
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor
|
||||
@ -70,7 +70,7 @@ class JittorMPIDriver(JittorDriver):
|
||||
def test_step(self, batch):
|
||||
return self._test_step(batch)
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
|
||||
reproducible: bool = False, sampler_or_batch_sampler=None):
|
||||
pass
|
||||
|
||||
|
@ -3,7 +3,7 @@ from typing import Dict, Union
|
||||
from .jittor_driver import JittorDriver
|
||||
from fastNLP.core.utils import auto_param_call
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
|
||||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler
|
||||
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor
|
||||
@ -99,25 +99,25 @@ class JittorSingleDriver(JittorDriver):
|
||||
def is_distributed(self):
|
||||
return False
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler],
|
||||
reproducible: bool = False, sampler_or_batch_sampler=None):
|
||||
# reproducible 的相关功能暂时没有实现
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
if isinstance(dist, RandomBatchSampler):
|
||||
raise NotImplementedError
|
||||
dataloader.batch_sampler = dist_sample
|
||||
if isinstance(dist, ReproducibleIterator):
|
||||
if isinstance(dist, ReproducibleSampler):
|
||||
raise NotImplementedError
|
||||
dataloader.batch_sampler.sampler = dist
|
||||
|
||||
if reproducible:
|
||||
raise NotImplementedError
|
||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
|
||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
|
||||
return dataloader
|
||||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
|
||||
elif isinstance(dataloader.batch_sampler, RandomBatchSampler):
|
||||
return dataloader
|
||||
else:
|
||||
# TODO
|
||||
batch_sampler = ReproducibleBatchSampler(
|
||||
batch_sampler = RandomBatchSampler(
|
||||
batch_sampler=dataloader.batch_sampler,
|
||||
batch_size=dataloader.batch_sampler.batch_size,
|
||||
drop_last=dataloader.drop_last
|
||||
|
@ -21,6 +21,8 @@ from fastNLP.core.utils import (
|
||||
is_in_paddle_dist,
|
||||
)
|
||||
from fastNLP.core.samplers import (
|
||||
RandomBatchSampler,
|
||||
ReproducibleSampler,
|
||||
ReproducibleIterator,
|
||||
RandomSampler,
|
||||
UnrepeatedDistributedSampler,
|
||||
@ -318,7 +320,7 @@ class PaddleFleetDriver(PaddleDriver):
|
||||
def test_step(self, batch):
|
||||
return self._test_step(batch)
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
|
||||
reproducible: bool = False, sampler_or_batch_sampler=None):
|
||||
# 暂时不支持iterableDataset
|
||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \
|
||||
|
@ -11,7 +11,11 @@ from fastNLP.core.utils import (
|
||||
get_paddle_device_id,
|
||||
paddle_move_data_to_device,
|
||||
)
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator, re_instantiate_sampler
|
||||
from fastNLP.core.samplers import (
|
||||
ReproducibleBatchSampler,
|
||||
ReproducibleIterator,
|
||||
re_instantiate_sampler,
|
||||
)
|
||||
from fastNLP.core.log import logger
|
||||
|
||||
if _NEED_IMPORT_PADDLE:
|
||||
|
@ -23,15 +23,16 @@ from fastNLP.core.drivers.torch_driver.utils import (
|
||||
ForwardState,
|
||||
_MODE_PARAMETER,
|
||||
reset_seed,
|
||||
replace_sampler
|
||||
replace_sampler,
|
||||
replace_batch_sampler
|
||||
)
|
||||
from fastNLP.core.drivers.utils import distributed_open_proc
|
||||
from fastNLP.core.utils import auto_param_call, check_user_specific_params
|
||||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
|
||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \
|
||||
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
|
||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
|
||||
from fastNLP.core.samplers import re_instantiate_sampler
|
||||
|
||||
|
||||
class TorchDDPDriver(TorchDriver):
|
||||
@ -445,25 +446,52 @@ class TorchDDPDriver(TorchDriver):
|
||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
|
||||
return self._test_step(batch)
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
|
||||
reproducible: bool = False, sampler_or_batch_sampler=None):
|
||||
if isinstance(dist, ReproducibleIterator):
|
||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数;
|
||||
dist = re_instantiate_sampler(dist)
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None,
|
||||
reproducible: bool = False):
|
||||
# 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用;
|
||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数;
|
||||
if isinstance(dist, RandomBatchSampler):
|
||||
dist.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
rank=self.global_rank,
|
||||
pad=True
|
||||
)
|
||||
return replace_batch_sampler(dataloader, dist)
|
||||
if isinstance(dist, ReproducibleSampler):
|
||||
dist.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
rank=self.global_rank,
|
||||
pad=True
|
||||
)
|
||||
return replace_sampler(dataloader, dist)
|
||||
|
||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用;
|
||||
# trainer, evaluator
|
||||
if dist is None:
|
||||
if reproducible:
|
||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
|
||||
"control.")
|
||||
else:
|
||||
if isinstance(dist, RandomBatchSampler):
|
||||
dist = re_instantiate_sampler(dist)
|
||||
return replace_batch_sampler(dataloader, dist)
|
||||
if isinstance(dist, ReproducibleSampler):
|
||||
dist = re_instantiate_sampler(dist)
|
||||
return replace_sampler(dataloader, dist)
|
||||
return dataloader
|
||||
# trainer
|
||||
elif dist == "dist":
|
||||
args = self.get_dataloader_args(dataloader)
|
||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为;
|
||||
if isinstance(args.sampler, ReproducibleIterator):
|
||||
if isinstance(args.batch_sampler, RandomBatchSampler):
|
||||
batch_sampler = re_instantiate_sampler(args.batch_sampler)
|
||||
batch_sampler.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
rank=self.global_rank,
|
||||
pad=True
|
||||
)
|
||||
return replace_batch_sampler(dataloader, batch_sampler)
|
||||
elif isinstance(args.sampler, ReproducibleSampler):
|
||||
sampler = re_instantiate_sampler(args.sampler)
|
||||
sampler.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
@ -477,21 +505,23 @@ class TorchDDPDriver(TorchDriver):
|
||||
shuffle=args.shuffle,
|
||||
seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
|
||||
)
|
||||
# todo 这个你写个todo吧,有两个角度;第一个是dataloader即使检测到sampler是我们reproducible,也不能直接set_distributeds; 第二个如果是单卡的,也需要替换sampler乃至切换sampler的状态,方式之前多卡,现在切换成单卡运行
|
||||
sampler.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
rank=self.global_rank,
|
||||
pad=True
|
||||
)
|
||||
return replace_sampler(dataloader, sampler)
|
||||
|
||||
# evaluator
|
||||
elif dist == "unrepeatdist":
|
||||
args = self.get_dataloader_args(dataloader)
|
||||
sampler = UnrepeatedDistributedSampler(
|
||||
dataset=args.dataset,
|
||||
shuffle=args.shuffle,
|
||||
)
|
||||
if isinstance(args.sampler, ReproducibleSampler):
|
||||
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
|
||||
elif not isinstance(args.sampler, UnrepeatedSampler):
|
||||
sampler = UnrepeatedSequentialSampler(
|
||||
dataset=args.dataset
|
||||
)
|
||||
else:
|
||||
sampler = re_instantiate_sampler(args.sampler)
|
||||
sampler.set_distributed(
|
||||
num_replicas=self.world_size,
|
||||
rank=self.global_rank
|
||||
|
@ -397,12 +397,13 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List:
|
||||
"""
|
||||
# # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题
|
||||
# obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
|
||||
if device is None:
|
||||
device = torch.cuda.current_device()
|
||||
if _TORCH_GREATER_EQUAL_1_8:
|
||||
objs = [None for _ in range(dist.get_world_size(group))]
|
||||
dist.all_gather_object(objs, obj)
|
||||
objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上
|
||||
return objs
|
||||
if device is None:
|
||||
device = torch.cuda.current_device()
|
||||
group = group if group is not None else torch.distributed.group.WORLD
|
||||
data = convert_to_tensors(obj, device=device)
|
||||
data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group)
|
||||
|
@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
|
||||
# world_size 和 rank
|
||||
if FASTNLP_BACKEND_LAUNCH in os.environ:
|
||||
if device is not None:
|
||||
logger.warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
|
||||
logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
|
||||
"up your script. And we will directly get the local device via "
|
||||
"`os.environ['LOCAL_RANK']`.")
|
||||
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)
|
||||
@ -39,11 +39,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
elif isinstance(device, int):
|
||||
if device < 0 and device != -1:
|
||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
|
||||
if device >= _could_use_device_num:
|
||||
if device < 0:
|
||||
if device != -1:
|
||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
|
||||
device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)]
|
||||
elif device >= _could_use_device_num:
|
||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.")
|
||||
device = torch.device(f"cuda:{device}")
|
||||
else:
|
||||
device = torch.device(f"cuda:{device}")
|
||||
elif isinstance(device, Sequence):
|
||||
device = list(set(device))
|
||||
for each in device:
|
||||
@ -62,7 +65,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
|
||||
if not isinstance(device, List):
|
||||
return TorchSingleDriver(model, device, **kwargs)
|
||||
else:
|
||||
logger.warning("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use "
|
||||
logger.info("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use "
|
||||
"`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter"
|
||||
"`driver` as `TorchDDPDriver`.")
|
||||
return TorchDDPDriver(model, device, **kwargs)
|
||||
|
@ -13,9 +13,8 @@ __all__ = [
|
||||
from .torch_driver import TorchDriver
|
||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
|
||||
from fastNLP.core.utils import auto_param_call
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
|
||||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.core.samplers import re_instantiate_sampler
|
||||
|
||||
|
||||
class TorchSingleDriver(TorchDriver):
|
||||
@ -130,25 +129,31 @@ class TorchSingleDriver(TorchDriver):
|
||||
else:
|
||||
return self._test_step(batch)
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None,
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None,
|
||||
reproducible: bool = False):
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
|
||||
# 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用;
|
||||
if isinstance(dist, RandomBatchSampler):
|
||||
return replace_batch_sampler(dataloader, dist)
|
||||
elif isinstance(dist, ReproducibleIterator):
|
||||
elif isinstance(dist, ReproducibleSampler):
|
||||
return replace_sampler(dataloader, dist)
|
||||
|
||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用;
|
||||
args = self.get_dataloader_args(dataloader)
|
||||
if isinstance(args.batch_sampler, RandomBatchSampler):
|
||||
batch_sampler = re_instantiate_sampler(args.batch_sampler)
|
||||
return replace_batch_sampler(dataloader, batch_sampler)
|
||||
elif isinstance(args.sampler, ReproducibleSampler):
|
||||
sampler = re_instantiate_sampler(args.sampler)
|
||||
return replace_sampler(dataloader, sampler)
|
||||
|
||||
if reproducible:
|
||||
args = self.get_dataloader_args(dataloader)
|
||||
if isinstance(args.sampler, ReproducibleIterator):
|
||||
sampler = re_instantiate_sampler(args.sampler)
|
||||
return replace_sampler(dataloader, sampler)
|
||||
else:
|
||||
batch_sampler = ReproducibleBatchSampler(
|
||||
batch_sampler=args.batch_sampler,
|
||||
batch_size=args.batch_size,
|
||||
drop_last=args.drop_last
|
||||
)
|
||||
return replace_batch_sampler(dataloader, batch_sampler)
|
||||
batch_sampler = RandomBatchSampler(
|
||||
batch_sampler=args.batch_sampler,
|
||||
batch_size=args.batch_size,
|
||||
drop_last=args.drop_last
|
||||
)
|
||||
return replace_batch_sampler(dataloader, batch_sampler)
|
||||
else:
|
||||
return dataloader
|
||||
|
||||
|
@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
|
||||
from fastNLP.envs import rank_zero_call
|
||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler
|
||||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator
|
||||
|
||||
|
||||
class TorchDriver(Driver):
|
||||
@ -143,8 +143,6 @@ class TorchDriver(Driver):
|
||||
|
||||
:param filepath: 保存到哪个文件夹;
|
||||
:param only_state_dict: 是否只保存权重;
|
||||
:param model_save_fn:
|
||||
|
||||
:return:
|
||||
"""
|
||||
model = self.unwrap_model()
|
||||
@ -184,10 +182,10 @@ class TorchDriver(Driver):
|
||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境;
|
||||
|
||||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch;
|
||||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的
|
||||
# sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`;
|
||||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的
|
||||
# sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `RandomBatchSampler`;
|
||||
dataloader_args = self.get_dataloader_args(dataloader)
|
||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
|
||||
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
|
||||
sampler = dataloader_args.batch_sampler
|
||||
elif dataloader_args.sampler:
|
||||
sampler = dataloader_args.sampler
|
||||
@ -247,25 +245,25 @@ class TorchDriver(Driver):
|
||||
|
||||
# 3. 恢复 sampler 的状态;
|
||||
dataloader_args = self.get_dataloader_args(dataloader)
|
||||
|
||||
sampler = dataloader_args.sampler
|
||||
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
|
||||
# 说明这里需要使用 ReproduceSampler 来弄一下了
|
||||
if self.is_distributed():
|
||||
raise RuntimeError(
|
||||
"It is not allowed to use single device checkpoint retraining before but ddp now.")
|
||||
sampler = ReproducibleBatchSampler(
|
||||
batch_sampler=sampler,
|
||||
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
|
||||
sampler = dataloader_args.batch_sampler
|
||||
elif isinstance(dataloader_args.sampler, ReproducibleIterator):
|
||||
sampler = dataloader_args.sampler
|
||||
elif self.is_distributed():
|
||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
|
||||
"`RandomBatchSampler` or `ReproducibleIterator`.")
|
||||
else:
|
||||
sampler = RandomBatchSampler(
|
||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
|
||||
batch_size=dataloader_args.batch_size,
|
||||
drop_last=dataloader_args.drop_last
|
||||
)
|
||||
sampler.load_state_dict(states['sampler_states'])
|
||||
|
||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
|
||||
|
||||
# 4. 修改 trainer_state.batch_idx_in_epoch
|
||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler;
|
||||
if not isinstance(sampler, ReproducibleBatchSampler):
|
||||
if not isinstance(sampler, RandomBatchSampler):
|
||||
if dataloader_args.drop_last:
|
||||
batch_idx_in_epoch = len(
|
||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
|
||||
@ -293,7 +291,7 @@ class TorchDriver(Driver):
|
||||
|
||||
@staticmethod
|
||||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
|
||||
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed
|
||||
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed
|
||||
with ``seed_everything(seed, workers=True)``.
|
||||
|
||||
See also the PyTorch documentation on
|
||||
|
@ -33,7 +33,7 @@ class TorchBackend(Backend):
|
||||
if dist.is_initialized():
|
||||
if method is None:
|
||||
raise AggregateMethodError(should_have_aggregate_method=True)
|
||||
tensor = self._gather_all(tensor)
|
||||
tensor = fastnlp_torch_all_gather(tensor)
|
||||
if isinstance(tensor[0], torch.Tensor):
|
||||
tensor = torch.stack(tensor)
|
||||
# 第一步, aggregate结果
|
||||
@ -68,59 +68,6 @@ class TorchBackend(Backend):
|
||||
def get_scalar(self, tensor) -> float:
|
||||
return tensor.item()
|
||||
|
||||
@staticmethod
|
||||
def _gather_all(result, group: Optional[Any] = None) -> List:
|
||||
"""Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
|
||||
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
|
||||
tensors are padded, gathered and then trimmed to secure equal workload for all processes.
|
||||
|
||||
Args:
|
||||
result: the value to sync
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
|
||||
Return:
|
||||
gathered_result: list with size equal to the process group where
|
||||
gathered_result[i] corresponds to result tensor from process i
|
||||
"""
|
||||
|
||||
if group is None:
|
||||
group = dist.group.WORLD
|
||||
|
||||
# convert tensors to contiguous format
|
||||
result = result.contiguous()
|
||||
|
||||
world_size = dist.get_world_size(group)
|
||||
dist.barrier(group=group)
|
||||
|
||||
# if the tensor is scalar, things are easy
|
||||
if result.ndim == 0:
|
||||
return _simple_gather_all_tensors(result, group, world_size)
|
||||
|
||||
# 1. Gather sizes of all tensors
|
||||
local_size = torch.tensor(result.shape, device=result.device)
|
||||
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
|
||||
dist.all_gather(local_sizes, local_size, group=group)
|
||||
max_size = torch.stack(local_sizes).max(dim=0).values
|
||||
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
|
||||
|
||||
# 2. If shapes are all the same, then do a simple gather:
|
||||
if all_sizes_equal:
|
||||
return _simple_gather_all_tensors(result, group, world_size)
|
||||
|
||||
# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
|
||||
pad_dims = []
|
||||
pad_by = (max_size - local_size).detach().cpu()
|
||||
for val in reversed(pad_by):
|
||||
pad_dims.append(0)
|
||||
pad_dims.append(val.item())
|
||||
result_padded = torch.nn.functional.pad(result, pad_dims)
|
||||
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
|
||||
dist.all_gather(gathered_result, result_padded, group)
|
||||
for idx, item_size in enumerate(local_sizes):
|
||||
slice_param = [slice(dim_size) for dim_size in item_size]
|
||||
gathered_result[idx] = gathered_result[idx][slice_param]
|
||||
return gathered_result
|
||||
|
||||
def tensor2numpy(self, tensor) -> np.array:
|
||||
"""
|
||||
将对应的tensor转为numpy对象
|
||||
|
@ -11,12 +11,12 @@ from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
|
||||
|
||||
|
||||
class Element:
|
||||
def __init__(self, value: float, aggregate_method, backend: Backend, name=None):
|
||||
def __init__(self, name, value: float, aggregate_method, backend: Backend):
|
||||
self.name = name
|
||||
self.init_value = value
|
||||
self.aggregate_method = aggregate_method
|
||||
self.name = name
|
||||
if backend == 'auto':
|
||||
raise RuntimeError("You have to specify the backend.")
|
||||
raise RuntimeError(f"You have to specify the backend for Element:{self.name}.")
|
||||
elif isinstance(backend, AutoBackend):
|
||||
self.backend = backend
|
||||
else:
|
||||
@ -34,20 +34,16 @@ class Element:
|
||||
自动aggregate对应的元素
|
||||
|
||||
"""
|
||||
self._check_value_initialized()
|
||||
try:
|
||||
self._value = self.backend.aggregate(self._value, self.aggregate_method)
|
||||
except AggregateMethodError as e:
|
||||
msg = 'If you see this message, please report a bug.'
|
||||
if self.name and e.should_have_aggregate_method:
|
||||
msg = f"Element:{self.name} has no specified `aggregate_method`."
|
||||
elif e.should_have_aggregate_method:
|
||||
msg = "Element has no specified `aggregate_method`."
|
||||
elif self.name and not e.should_have_aggregate_method:
|
||||
msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \
|
||||
f'aggregate_method:{self.aggregate_method}.'
|
||||
elif not e.should_have_aggregate_method:
|
||||
msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \
|
||||
f'aggregate_method:{self.aggregate_method}.'
|
||||
if e.only_warn:
|
||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
|
||||
logger.warning(msg)
|
||||
@ -74,6 +70,7 @@ class Element:
|
||||
return self._value
|
||||
|
||||
def get_scalar(self) -> float:
|
||||
self._check_value_initialized()
|
||||
return self.backend.get_scalar(self._value)
|
||||
|
||||
def fill_value(self, value):
|
||||
@ -95,7 +92,7 @@ class Element:
|
||||
|
||||
def _check_value_when_call(self):
|
||||
if self.value is None:
|
||||
prefix = f'Element:`{self.name}`' if self.name else 'Element'
|
||||
prefix = f'Element:`{self.name}`'
|
||||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
|
||||
"element, or use it after it being used by the `Metric.compute()` method.")
|
||||
|
||||
@ -273,9 +270,10 @@ class Element:
|
||||
"""
|
||||
try:
|
||||
if self._value is None:
|
||||
prefix = f'Element:`{self.name}`' if self.name else 'Element'
|
||||
prefix = f'Element:`{self.name}`'
|
||||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
|
||||
"element, or use it after it being used by the `Metric.compute()` method.")
|
||||
return getattr(self._value, item)
|
||||
except AttributeError as e:
|
||||
logger.error(f"Element:{self.name} has no `{item}` attribute.")
|
||||
raise e
|
||||
|
@ -35,7 +35,7 @@ class Metric:
|
||||
def elements(self) -> dict:
|
||||
return self._elements
|
||||
|
||||
def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element:
|
||||
def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element:
|
||||
"""
|
||||
注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的
|
||||
tensor 直接进行加减乘除计算即可。
|
||||
@ -57,11 +57,9 @@ class Metric:
|
||||
else:
|
||||
backend = AutoBackend(backend)
|
||||
|
||||
# 当name为None,默认为变量取得变量名
|
||||
if name is None:
|
||||
name = f'ele_var_{len(self._elements)}'
|
||||
assert name is not None and name not in self.elements
|
||||
|
||||
element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name)
|
||||
element = Element(name=name, value=value, aggregate_method=aggregate_method, backend=backend)
|
||||
self.elements[name] = element
|
||||
setattr(self, name, element)
|
||||
return element
|
||||
|
@ -216,9 +216,26 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp):
|
||||
|
||||
class SpanFPreRecMetric(Metric):
|
||||
|
||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None,
|
||||
encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro',
|
||||
beta=1, aggregate_when_get_metric: bool = True,) -> None:
|
||||
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None,
|
||||
only_gross: bool = True, f_type='micro',
|
||||
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None:
|
||||
r"""
|
||||
|
||||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN),
|
||||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'.
|
||||
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据
|
||||
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据
|
||||
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。
|
||||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断.
|
||||
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label
|
||||
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec
|
||||
:param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同)
|
||||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
|
||||
:param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update()
|
||||
函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。
|
||||
:param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric,
|
||||
当 backend 不支持分布式时,该参数无意义。
|
||||
"""
|
||||
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
|
||||
if f_type not in ('micro', 'macro'):
|
||||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
|
||||
@ -249,16 +266,25 @@ class SpanFPreRecMetric(Metric):
|
||||
self.only_gross = only_gross
|
||||
self.tag_vocab = tag_vocab
|
||||
|
||||
self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
|
||||
self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
|
||||
self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None))
|
||||
self._true_positives = {}
|
||||
self._false_positives = {}
|
||||
self._false_negatives = {}
|
||||
for word, _ in tag_vocab:
|
||||
word = word.lower()
|
||||
if word != 'o':
|
||||
word = word[2:]
|
||||
if word in self._true_positives:
|
||||
continue
|
||||
self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend)
|
||||
self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend)
|
||||
self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend)
|
||||
|
||||
def get_metric(self) -> dict:
|
||||
evaluate_result = {}
|
||||
if not self.only_gross or self.f_type == 'macro':
|
||||
tags = set(self._false_negatives.keys())
|
||||
tags.update(set(self._false_positives.keys()))
|
||||
tags.update(set(self._true_positives.keys()))
|
||||
tags.update(self._false_positives.keys())
|
||||
tags.update(self._true_positives.keys())
|
||||
f_sum = 0
|
||||
pre_sum = 0
|
||||
rec_sum = 0
|
||||
@ -266,6 +292,9 @@ class SpanFPreRecMetric(Metric):
|
||||
tp = self._true_positives[tag].get_scalar()
|
||||
fn = self._false_negatives[tag].get_scalar()
|
||||
fp = self._false_positives[tag].get_scalar()
|
||||
if tp == fn == fp == 0:
|
||||
continue
|
||||
|
||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
|
||||
f_sum += f
|
||||
pre_sum += pre
|
||||
@ -284,10 +313,17 @@ class SpanFPreRecMetric(Metric):
|
||||
evaluate_result['rec'] = rec_sum / len(tags)
|
||||
|
||||
if self.f_type == 'micro':
|
||||
tp, fn, fp = [], [], []
|
||||
for val in self._true_positives.values():
|
||||
tp.append(val.get_scalar())
|
||||
for val in self._false_negatives.values():
|
||||
fn.append(val.get_scalar())
|
||||
for val in self._false_positives.values():
|
||||
fp.append(val.get_scalar())
|
||||
f, pre, rec = _compute_f_pre_rec(self.beta_square,
|
||||
sum(val.get_scalar() for val in self._true_positives.values()),
|
||||
sum(val.get_scalar() for val in self._false_negatives.values()),
|
||||
sum(val.get_scalar() for val in self._false_positives.values()))
|
||||
sum(tp),
|
||||
sum(fn),
|
||||
sum(fp))
|
||||
evaluate_result['f'] = f
|
||||
evaluate_result['pre'] = pre
|
||||
evaluate_result['rec'] = rec
|
||||
|
@ -3,19 +3,30 @@ __all__ = [
|
||||
'SortedSampler',
|
||||
'ConstTokenNumSampler',
|
||||
'ConstantTokenNumSampler',
|
||||
'UnrepeatedDistributedSampler',
|
||||
|
||||
'MixSampler',
|
||||
'InnerSampler',
|
||||
'DopedSampler',
|
||||
'MixSequentialSampler',
|
||||
'PollingSampler',
|
||||
'ReproducibleIterator',
|
||||
|
||||
'ReproducibleSampler',
|
||||
'RandomSampler',
|
||||
're_instantiate_sampler'
|
||||
"SequentialSampler",
|
||||
"SortedSampler",
|
||||
|
||||
'UnrepeatedSampler',
|
||||
'UnrepeatedRandomSampler',
|
||||
"UnrepeatedSortedSampler",
|
||||
"UnrepeatedSequentialSampler",
|
||||
|
||||
"re_instantiate_sampler",
|
||||
"conversion_between_reproducible_and_unrepeated_sampler"
|
||||
]
|
||||
|
||||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler
|
||||
from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler
|
||||
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler
|
||||
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler
|
||||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
|
||||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
|
||||
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
|
||||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
|
||||
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler
|
||||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler
|
||||
|
||||
|
@ -4,7 +4,6 @@ from typing import Union, List, Iterable, Dict
|
||||
|
||||
__all__ = [
|
||||
'MixSampler',
|
||||
'InnerSampler',
|
||||
'DopedSampler',
|
||||
'MixSequentialSampler',
|
||||
'PollingSampler'
|
||||
|
@ -1,6 +1,6 @@
|
||||
__all__ = [
|
||||
'BucketedBatchSampler',
|
||||
"ReproducibleBatchSampler"
|
||||
"RandomBatchSampler"
|
||||
]
|
||||
|
||||
import math
|
||||
@ -16,7 +16,7 @@ from fastNLP.core.log import logger
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class ReproducibleBatchIterator:
|
||||
class ReproducibleBatchSampler:
|
||||
@abstractmethod
|
||||
def set_distributed(self, num_replicas, rank, pad=True):
|
||||
raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.")
|
||||
@ -42,13 +42,13 @@ class ReproducibleBatchIterator:
|
||||
pass
|
||||
|
||||
|
||||
class ReproducibleBatchSampler(ReproducibleBatchIterator):
|
||||
class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
|
||||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
|
||||
"""
|
||||
可以使得 batch_sampler 对象状态恢复的 wrapper 。
|
||||
|
||||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproducibleBatchSampler 将首先遍历一边该对象,然后将迭代
|
||||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一边该对象,然后将迭代
|
||||
出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。
|
||||
:param batch_size: 每个 batch 的大小是多少。
|
||||
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。
|
||||
@ -138,7 +138,7 @@ class ReproducibleBatchSampler(ReproducibleBatchIterator):
|
||||
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size
|
||||
|
||||
|
||||
class BucketedBatchSampler(ReproducibleBatchIterator):
|
||||
class BucketedBatchSampler(ReproducibleBatchSampler):
|
||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
|
||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
|
||||
"""
|
||||
|
@ -1,25 +1,21 @@
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Union
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.core.dataset import DataSet
|
||||
|
||||
__all__ = [
|
||||
'ReproducibleIterator',
|
||||
'ReproducibleSampler',
|
||||
'RandomSampler',
|
||||
're_instantiate_sampler'
|
||||
"SortedSampler",
|
||||
"SequentialSampler"
|
||||
]
|
||||
|
||||
|
||||
def re_instantiate_sampler(sampler):
|
||||
all_attributes = vars(sampler)
|
||||
return type(sampler)(**all_attributes)
|
||||
|
||||
|
||||
|
||||
class ReproducibleIterator:
|
||||
class ReproducibleSampler:
|
||||
"""
|
||||
注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler
|
||||
注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler
|
||||
或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。
|
||||
|
||||
"""
|
||||
@ -47,7 +43,7 @@ class ReproducibleIterator:
|
||||
pass
|
||||
|
||||
|
||||
class RandomSampler(ReproducibleIterator):
|
||||
class RandomSampler(ReproducibleSampler):
|
||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
|
||||
"""
|
||||
|
||||
@ -157,8 +153,8 @@ class RandomSampler(ReproducibleIterator):
|
||||
f"we cannot use {self.__class__.__name__} to load it."
|
||||
|
||||
length = states['length']
|
||||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
|
||||
"and current dataset."
|
||||
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
|
||||
f"and current dataset({len(self.dataset)})."
|
||||
self.seed = states['seed']
|
||||
self.epoch = states['epoch']
|
||||
self.num_consumed_samples = states['num_consumed_samples']
|
||||
@ -215,9 +211,132 @@ class RandomSampler(ReproducibleIterator):
|
||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
|
||||
|
||||
|
||||
class SequentialSampler(RandomSampler):
|
||||
def __init__(self, dataset, dist_mode:str='interval', **kwargs):
|
||||
"""
|
||||
按照顺序读取 dataset 。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。
|
||||
|
||||
:param dataset: 实现了 __len__ 方法的数据容器。
|
||||
:param kwargs:
|
||||
"""
|
||||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
|
||||
|
||||
def __iter__(self):
|
||||
if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了
|
||||
self.num_consumed_samples = 0
|
||||
self.during_iter = True
|
||||
indices = self.generate_indices()
|
||||
|
||||
if self.pad:
|
||||
# add extra samples to make it evenly divisible
|
||||
padding_size = self.total_size - len(indices)
|
||||
if padding_size <= len(indices):
|
||||
indices += indices[:padding_size]
|
||||
else:
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[:self.total_size]
|
||||
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.num_consumed_samples:]
|
||||
indices = indices[self.rank:len(indices):self.num_replicas]
|
||||
assert len(indices) == self.num_left_samples
|
||||
|
||||
for index in indices:
|
||||
self.num_consumed_samples += self.num_replicas
|
||||
yield index
|
||||
self.during_iter = False
|
||||
self.num_consumed_samples = 0
|
||||
|
||||
def generate_indices(self) -> List[int]:
|
||||
"""
|
||||
生成随机序列
|
||||
|
||||
:return:
|
||||
"""
|
||||
return list(range(len(self.dataset)))
|
||||
|
||||
def state_dict(self) -> Dict:
|
||||
states = {
|
||||
'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据;
|
||||
'sampler_type': self.__class__.__name__,
|
||||
'length': len(self.dataset),
|
||||
}
|
||||
return states
|
||||
|
||||
def load_state_dict(self, states: Dict):
|
||||
# 如果 self.during_iter 是 True,那么 data_idx 一定是 0;
|
||||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
|
||||
"during an unfinished iteration."
|
||||
|
||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
|
||||
f"we cannot use {self.__class__.__name__} to load it."
|
||||
|
||||
length = states['length']
|
||||
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
|
||||
f"and current dataset({len(self.dataset)})."
|
||||
self.num_consumed_samples = states['num_consumed_samples']
|
||||
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0
|
||||
self.num_consumed_samples = 0
|
||||
|
||||
|
||||
class SortedSampler(SequentialSampler):
|
||||
def __init__(self, dataset, length:Union[str, List], **kwargs):
|
||||
"""
|
||||
将 dataset 中的数据根据 length 从长到短进行迭代。在多卡情况下,由于padding 最后一个 sample 可能是最长的那个 sample。
|
||||
|
||||
:param dataset: 实现了 __len__ 方法的数据容器。
|
||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的
|
||||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。
|
||||
:param seed: 设置的随机数种子
|
||||
:param kwargs: fastNLP 保留使用
|
||||
"""
|
||||
super().__init__(dataset=dataset, **kwargs)
|
||||
if isinstance(dataset, DataSet):
|
||||
length = dataset.get_field(length)
|
||||
if not isinstance(length[0], int):
|
||||
length = list(map(len, length))
|
||||
else:
|
||||
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
|
||||
"the length parameter can only be List[int]"
|
||||
|
||||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal."
|
||||
|
||||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。
|
||||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的
|
||||
|
||||
def generate_indices(self) -> List[int]:
|
||||
return self.sorted_indices
|
||||
|
||||
def __iter__(self):
|
||||
if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了
|
||||
self.num_consumed_samples = 0
|
||||
self.during_iter = True
|
||||
indices = self.generate_indices()
|
||||
|
||||
if self.pad:
|
||||
padding_size = self.total_size - len(indices)
|
||||
if padding_size <= len(indices):
|
||||
indices += indices[:padding_size]
|
||||
else:
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[:self.total_size]
|
||||
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.num_consumed_samples:]
|
||||
indices = indices[self.rank:len(indices):self.num_replicas]
|
||||
assert len(indices) == self.num_left_samples
|
||||
|
||||
for index in indices:
|
||||
self.num_consumed_samples += self.num_replicas
|
||||
yield index
|
||||
self.during_iter = False
|
||||
self.num_consumed_samples = 0
|
||||
|
||||
|
@ -7,7 +7,6 @@ __all__ = [
|
||||
"SortedSampler",
|
||||
'ConstTokenNumSampler',
|
||||
"ConstantTokenNumSampler",
|
||||
"UnrepeatedDistributedSampler",
|
||||
]
|
||||
|
||||
from itertools import chain
|
||||
@ -18,7 +17,7 @@ import numpy as np
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
from torch.utils.data import SequentialSampler, Sampler, RandomSampler
|
||||
from torch.utils.data import Sampler
|
||||
else:
|
||||
from fastNLP.core.utils.dummy_class import DummyClass as Sampler
|
||||
|
||||
@ -727,87 +726,3 @@ def k_means_bucketing(lengths, buckets):
|
||||
if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
|
||||
bucket_data[bucket_id].append(idx)
|
||||
return bucket_data
|
||||
|
||||
|
||||
class UnrepeatedDistributedSampler:
|
||||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0):
|
||||
"""
|
||||
考虑在多卡evaluate的场景下,不能重复sample。
|
||||
|
||||
:param dataset:
|
||||
:param shuffle:
|
||||
:param seed:
|
||||
"""
|
||||
self.dataset = dataset
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
|
||||
# 多卡的相关的参数
|
||||
self.num_replicas = 1
|
||||
self.rank = 0
|
||||
self.epoch = -1
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank;
|
||||
:return:
|
||||
"""
|
||||
num_common = len(self.dataset)//self.num_replicas
|
||||
self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
|
||||
return self.num_samples
|
||||
|
||||
def __iter__(self):
|
||||
r"""
|
||||
当前使用num_consumed_samples做法会在交替使用的时候遇到问题;
|
||||
Example:
|
||||
>>> sampler = RandomSampler()
|
||||
>>> iter1 = iter(sampler)
|
||||
>>> iter2 = iter(sampler)
|
||||
>>> next(iter1)
|
||||
>>> next(iter2) # 当前num_consumed_samples的数量会发生变化
|
||||
"""
|
||||
|
||||
indices = self.generate_indices()
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:len(indices):self.num_replicas]
|
||||
assert len(indices) == len(self)
|
||||
|
||||
for index in indices:
|
||||
yield index
|
||||
|
||||
def generate_indices(self) -> List[int]:
|
||||
"""
|
||||
生成随机序列
|
||||
|
||||
:return:
|
||||
"""
|
||||
if self.shuffle:
|
||||
indices = list(range(len(self.dataset)))
|
||||
seed = self.seed + self.epoch
|
||||
rng = np.random.default_rng(abs(seed))
|
||||
rng.shuffle(indices)
|
||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。
|
||||
self.epoch -= 1
|
||||
else:
|
||||
indices = list(range(len(self.dataset)))
|
||||
return indices
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
self.epoch = epoch
|
||||
|
||||
def set_distributed(self, num_replicas, rank):
|
||||
"""
|
||||
该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 sampler 本身后立即被调用;
|
||||
|
||||
:param num_replicas:
|
||||
:param rank:
|
||||
:return:
|
||||
"""
|
||||
assert num_replicas>0 and isinstance(num_replicas, int)
|
||||
assert isinstance(rank, int) and 0<=rank<num_replicas
|
||||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态;
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
|
||||
return self
|
143
fastNLP/core/samplers/unrepeated_sampler.py
Normal file
143
fastNLP/core/samplers/unrepeated_sampler.py
Normal file
@ -0,0 +1,143 @@
|
||||
__all__ = [
|
||||
'UnrepeatedSampler',
|
||||
'UnrepeatedSortedSampler',
|
||||
'UnrepeatedRandomSampler',
|
||||
"UnrepeatedSequentialSampler"
|
||||
]
|
||||
|
||||
from typing import List, Union
|
||||
from fastNLP.core.dataset import DataSet
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class UnrepeatedSampler:
|
||||
"""
|
||||
在多卡场景下保证 indice 不重复的 sampler
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class UnrepeatedRandomSampler(UnrepeatedSampler):
|
||||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs):
|
||||
"""
|
||||
考虑在多卡evaluate的场景下,不能重复sample。
|
||||
|
||||
:param dataset: 实现了 __len__ 方法的数据容器。
|
||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。
|
||||
:param seed: 设置的随机数种子
|
||||
:param kwargs: fastNLP 保留使用
|
||||
"""
|
||||
self.dataset = dataset
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
|
||||
# 多卡的相关的参数
|
||||
self.num_replicas = kwargs.get('num_replicas', 1)
|
||||
self.rank = kwargs.get('rank', 0)
|
||||
self.epoch = kwargs.get('epoch', -1)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank;
|
||||
:return:
|
||||
"""
|
||||
num_common = len(self.dataset)//self.num_replicas
|
||||
num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
|
||||
return num_samples
|
||||
|
||||
def __iter__(self):
|
||||
indices = self.generate_indices()
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:len(indices):self.num_replicas]
|
||||
assert len(indices) == len(self)
|
||||
|
||||
for index in indices:
|
||||
yield index
|
||||
|
||||
def generate_indices(self) -> List[int]:
|
||||
"""
|
||||
生成随机序列
|
||||
|
||||
:return:
|
||||
"""
|
||||
if self.shuffle:
|
||||
indices = list(range(len(self.dataset)))
|
||||
seed = self.seed + self.epoch
|
||||
rng = np.random.default_rng(abs(seed))
|
||||
rng.shuffle(indices)
|
||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。
|
||||
self.epoch -= 1
|
||||
else:
|
||||
indices = list(range(len(self.dataset)))
|
||||
return indices
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
self.epoch = epoch
|
||||
|
||||
def set_distributed(self, num_replicas, rank):
|
||||
"""
|
||||
该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 sampler 本身后立即被调用;
|
||||
|
||||
:param num_replicas:
|
||||
:param rank:
|
||||
:return:
|
||||
"""
|
||||
assert num_replicas>0 and isinstance(num_replicas, int)
|
||||
assert isinstance(rank, int) and 0<=rank<num_replicas
|
||||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态;
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class UnrepeatedSortedSampler(UnrepeatedRandomSampler):
|
||||
def __init__(self, dataset, length:Union[str, List], **kwargs):
|
||||
"""
|
||||
将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的
|
||||
batch 数量不完全一致。
|
||||
|
||||
:param dataset: 实现了 __len__ 方法的数据容器。
|
||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的
|
||||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。
|
||||
:param kwargs: fastNLP 保留使用
|
||||
"""
|
||||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
|
||||
if isinstance(dataset, DataSet):
|
||||
length = dataset.get_field(length)
|
||||
if not isinstance(length[0], int):
|
||||
length = list(map(len, length))
|
||||
else:
|
||||
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
|
||||
"the length parameter can only be List[int]"
|
||||
|
||||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal."
|
||||
|
||||
length = np.array(length, dtype=int) # 按照长到短排列的序号。
|
||||
self.sorted_indices = np.argsort(length)[::-1].tolist() # 按长度从高到低排序的
|
||||
|
||||
def generate_indices(self) -> List[int]:
|
||||
return self.sorted_indices
|
||||
|
||||
|
||||
class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
|
||||
def __init__(self, dataset, **kwargs):
|
||||
"""
|
||||
按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。
|
||||
|
||||
:param dataset: 实现了 __len__ 方法的数据容器。
|
||||
:param kwargs:
|
||||
"""
|
||||
super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs)
|
||||
|
||||
def __iter__(self):
|
||||
indices = self.generate_indices()
|
||||
indices = indices[self.rank:len(indices):self.num_replicas]
|
||||
for index in indices:
|
||||
yield index
|
||||
|
||||
def generate_indices(self) -> List[int]:
|
||||
return list(range(len(self.dataset)))
|
||||
|
42
fastNLP/core/samplers/utils.py
Normal file
42
fastNLP/core/samplers/utils.py
Normal file
@ -0,0 +1,42 @@
|
||||
__all__ = [
|
||||
're_instantiate_sampler',
|
||||
'conversion_between_reproducible_and_unrepeated_sampler'
|
||||
]
|
||||
|
||||
from fastNLP.core.samplers.unrepeated_sampler import *
|
||||
from fastNLP.core.samplers.reproducible_sampler import *
|
||||
|
||||
|
||||
def conversion_between_reproducible_and_unrepeated_sampler(sampler):
|
||||
"""
|
||||
将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的
|
||||
ReproducibleSampler,
|
||||
|
||||
:param sampler:
|
||||
:return:
|
||||
"""
|
||||
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
|
||||
"The sampler must be UnrepeatedSampler or ReproducibleSampler"
|
||||
if isinstance(sampler, UnrepeatedSampler):
|
||||
if isinstance(sampler, UnrepeatedRandomSampler):
|
||||
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
|
||||
elif isinstance(sampler, UnrepeatedSequentialSampler):
|
||||
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
|
||||
elif isinstance(sampler, UnrepeatedSortedSampler):
|
||||
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
|
||||
raise TypeError(f"{sampler.__class__} has no unrepeated version.")
|
||||
else:
|
||||
if isinstance(sampler, RandomSampler):
|
||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
|
||||
elif isinstance(sampler, SequentialSampler):
|
||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
|
||||
elif isinstance(sampler, SortedSampler):
|
||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
|
||||
raise TypeError(f"{sampler.__class__} has no reproducible version.")
|
||||
|
||||
|
||||
def re_instantiate_sampler(sampler, new_sampler_class=None):
|
||||
all_attributes = vars(sampler)
|
||||
if new_sampler_class is not None:
|
||||
return new_sampler_class(**all_attributes)
|
||||
return type(sampler)(**all_attributes)
|
@ -96,6 +96,7 @@ class FRichProgress(Progress, metaclass=Singleton):
|
||||
|
||||
# start new
|
||||
self.start()
|
||||
self.console.show_cursor(show=True)
|
||||
return self
|
||||
|
||||
def set_transient(self, transient: bool = True):
|
||||
@ -149,6 +150,9 @@ class FRichProgress(Progress, metaclass=Singleton):
|
||||
super().stop_task(task_id)
|
||||
super().remove_task(task_id)
|
||||
|
||||
def start(self) -> None:
|
||||
super().start()
|
||||
self.console.show_cursor(show=True)
|
||||
|
||||
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
|
||||
f_rich_progress = FRichProgress().new_progess(
|
||||
@ -161,7 +165,7 @@ if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
|
||||
TextColumn("{task.fields[post_desc]}", justify="right"),
|
||||
transient=True,
|
||||
disable=False,
|
||||
speed_estimate_period=10
|
||||
speed_estimate_period=1
|
||||
)
|
||||
else:
|
||||
f_rich_progress = DummyFRichProgress()
|
||||
|
@ -44,6 +44,9 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_fn_arg_names(fn: Callable) -> List[str]:
|
||||
r"""
|
||||
返回一个函数的所有参数的名字;
|
||||
|
@ -153,7 +153,7 @@ def seed_jittor_global_seed(global_seed):
|
||||
pass
|
||||
|
||||
|
||||
def dump_fastnlp_backend(default:bool = False):
|
||||
def dump_fastnlp_backend(default:bool = False, backend=None):
|
||||
"""
|
||||
将 fastNLP 的设置写入到 ~/.fastNLP/envs/ 文件夹下,
|
||||
若 default 为 True,则保存的文件为 ~/.fastNLP/envs/default.json 。
|
||||
@ -165,6 +165,7 @@ def dump_fastnlp_backend(default:bool = False):
|
||||
会保存的环境变量为 FASTNLP_BACKEND 。
|
||||
|
||||
:param default:
|
||||
:param backend: 保存使用的 backend 为哪个值,允许的值有 ['torch', 'paddle', 'jittor']。如果为 None ,则使用环境变量中的值。
|
||||
:return:
|
||||
"""
|
||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
|
||||
@ -179,10 +180,16 @@ def dump_fastnlp_backend(default:bool = False):
|
||||
os.makedirs(os.path.dirname(env_path), exist_ok=True)
|
||||
|
||||
envs = {}
|
||||
if FASTNLP_BACKEND in os.environ:
|
||||
envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND]
|
||||
assert backend in SUPPORT_BACKENDS, f"fastNLP only supports {SUPPORT_BACKENDS} right now."
|
||||
if backend is None:
|
||||
if FASTNLP_BACKEND in os.environ:
|
||||
envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND]
|
||||
else:
|
||||
envs[FASTNLP_BACKEND] = backend
|
||||
if len(envs):
|
||||
with open(env_path, 'w', encoding='utf8') as f:
|
||||
json.dump(fp=f, obj=envs)
|
||||
|
||||
print(f"Writing the default fastNLP backend:{envs[FASTNLP_BACKEND]} to {env_path}.")
|
||||
else:
|
||||
raise RuntimeError("No backend specified.")
|
@ -47,7 +47,8 @@ def set_env_on_import_paddle():
|
||||
# TODO jittor may need set this
|
||||
def set_env_on_import_jittor():
|
||||
# todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH
|
||||
pass
|
||||
if 'log_silent' not in os.environ:
|
||||
os.environ['log_silent'] = '1'
|
||||
|
||||
|
||||
def set_env_on_import():
|
||||
@ -63,7 +64,7 @@ def set_env_on_import():
|
||||
|
||||
# fastNLP 内部使用的一些变量
|
||||
if FASTNLP_LAUNCH_TIME not in os.environ:
|
||||
cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%M_%f')}"
|
||||
cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')}"
|
||||
os.environ[FASTNLP_LAUNCH_TIME] = cur_time
|
||||
|
||||
# 设置对应的值
|
||||
|
@ -8,7 +8,7 @@ import torch.distributed as dist
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback
|
||||
from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
|
||||
from fastNLP.core.controllers.trainer import Trainer
|
||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK
|
||||
|
||||
@ -80,16 +80,23 @@ def test_model_checkpoint_callback_1(
|
||||
version,
|
||||
only_state_dict
|
||||
):
|
||||
# def test_model_checkpoint_callback_1(
|
||||
# model_and_optimizers: TrainerParameters,
|
||||
# driver='torch_ddp',
|
||||
# device=[0, 1],
|
||||
# version=1,
|
||||
# only_state_dict=True
|
||||
# ):
|
||||
path = Path.cwd().joinpath(f"test_model_checkpoint")
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
if version == 0:
|
||||
callbacks = [
|
||||
CheckpointCallback(
|
||||
ModelCheckpointCallback(
|
||||
monitor="acc",
|
||||
save_folder=path,
|
||||
save_every_n_epochs=1,
|
||||
save_every_n_global_batches=123, # 避免和 epoch 的保存重复;
|
||||
save_every_n_batches=123, # 避免和 epoch 的保存重复;
|
||||
save_topk=None,
|
||||
save_last=False,
|
||||
save_on_exception=None,
|
||||
@ -98,11 +105,11 @@ def test_model_checkpoint_callback_1(
|
||||
]
|
||||
elif version == 1:
|
||||
callbacks = [
|
||||
CheckpointCallback(
|
||||
ModelCheckpointCallback(
|
||||
monitor="acc",
|
||||
save_folder=path,
|
||||
save_every_n_epochs=3,
|
||||
save_every_n_global_batches=None,
|
||||
save_every_n_batches=None,
|
||||
save_topk=2,
|
||||
save_last=True,
|
||||
save_on_exception=None,
|
||||
@ -121,7 +128,6 @@ def test_model_checkpoint_callback_1(
|
||||
input_mapping=model_and_optimizers.input_mapping,
|
||||
output_mapping=model_and_optimizers.output_mapping,
|
||||
metrics=model_and_optimizers.metrics,
|
||||
|
||||
n_epochs=10,
|
||||
callbacks=callbacks,
|
||||
output_from_new_proc="all"
|
||||
@ -134,31 +140,31 @@ def test_model_checkpoint_callback_1(
|
||||
if version == 0:
|
||||
|
||||
if driver == "torch":
|
||||
assert "epoch_10-global_batch_250-acc" in all_saved_model_paths
|
||||
assert "epoch_4-global_batch_123-acc" in all_saved_model_paths
|
||||
assert "model-epoch_10" in all_saved_model_paths
|
||||
assert "model-epoch_4-batch_123" in all_saved_model_paths
|
||||
|
||||
epoch_save_path = all_saved_model_paths["epoch_10-global_batch_250-acc"]
|
||||
step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"]
|
||||
epoch_save_path = all_saved_model_paths["model-epoch_10"]
|
||||
step_save_path = all_saved_model_paths["model-epoch_4-batch_123"]
|
||||
|
||||
assert len(all_saved_model_paths) == 12
|
||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完;
|
||||
else:
|
||||
assert "epoch_6-global_batch_78-acc" in all_saved_model_paths
|
||||
assert "epoch_9-global_batch_123-acc" in all_saved_model_paths
|
||||
assert "model-epoch_6" in all_saved_model_paths
|
||||
assert "model-epoch_9-batch_123" in all_saved_model_paths
|
||||
|
||||
epoch_save_path = all_saved_model_paths["epoch_6-global_batch_78-acc"]
|
||||
step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"]
|
||||
epoch_save_path = all_saved_model_paths["model-epoch_6"]
|
||||
step_save_path = all_saved_model_paths["model-epoch_9-batch_123"]
|
||||
|
||||
assert len(all_saved_model_paths) == 11
|
||||
all_state_dicts = [epoch_save_path, step_save_path]
|
||||
|
||||
elif version == 1:
|
||||
|
||||
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
|
||||
pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
|
||||
|
||||
if driver == "torch":
|
||||
assert "epoch_9-global_batch_225-acc" in all_saved_model_paths
|
||||
assert "last" in all_saved_model_paths
|
||||
assert "model-epoch_9" in all_saved_model_paths
|
||||
assert "model-last" in all_saved_model_paths
|
||||
aLL_topk_folders = []
|
||||
for each_folder_name in all_saved_model_paths:
|
||||
each_folder_name = pattern.findall(each_folder_name)
|
||||
@ -166,15 +172,15 @@ def test_model_checkpoint_callback_1(
|
||||
aLL_topk_folders.append(each_folder_name[0])
|
||||
assert len(aLL_topk_folders) == 2
|
||||
|
||||
epoch_save_path = all_saved_model_paths["epoch_9-global_batch_225-acc"]
|
||||
last_save_path = all_saved_model_paths["last"]
|
||||
epoch_save_path = all_saved_model_paths["model-epoch_9"]
|
||||
last_save_path = all_saved_model_paths["model-last"]
|
||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
|
||||
|
||||
assert len(all_saved_model_paths) == 6
|
||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完;
|
||||
else:
|
||||
assert "epoch_9-global_batch_117-acc" in all_saved_model_paths
|
||||
assert "last" in all_saved_model_paths
|
||||
assert "model-epoch_9" in all_saved_model_paths
|
||||
assert "model-last" in all_saved_model_paths
|
||||
|
||||
aLL_topk_folders = []
|
||||
for each_folder_name in all_saved_model_paths:
|
||||
@ -183,8 +189,8 @@ def test_model_checkpoint_callback_1(
|
||||
aLL_topk_folders.append(each_folder_name[0])
|
||||
assert len(aLL_topk_folders) == 2
|
||||
|
||||
epoch_save_path = all_saved_model_paths["epoch_9-global_batch_117-acc"]
|
||||
last_save_path = all_saved_model_paths["last"]
|
||||
epoch_save_path = all_saved_model_paths["model-epoch_9"]
|
||||
last_save_path = all_saved_model_paths["model-last"]
|
||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
|
||||
|
||||
assert len(all_saved_model_paths) == 6
|
||||
@ -212,7 +218,7 @@ def test_model_checkpoint_callback_1(
|
||||
|
||||
finally:
|
||||
synchronize_safe_rm(path)
|
||||
# pass
|
||||
pass
|
||||
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
@ -238,11 +244,11 @@ def test_model_checkpoint_callback_2(
|
||||
raise NotImplementedError
|
||||
|
||||
callbacks = [
|
||||
CheckpointCallback(
|
||||
ModelCheckpointCallback(
|
||||
monitor="acc1",
|
||||
save_folder=path,
|
||||
save_every_n_epochs=None,
|
||||
save_every_n_global_batches=None,
|
||||
save_every_n_batches=None,
|
||||
save_topk=None,
|
||||
save_last=False,
|
||||
save_on_exception=NotImplementedError,
|
||||
@ -279,12 +285,12 @@ def test_model_checkpoint_callback_2(
|
||||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
|
||||
|
||||
if driver == "torch":
|
||||
assert "epoch_4-global_batch_100-acc_NotImplementedError" in all_saved_model_paths
|
||||
exception_model_path = all_saved_model_paths["epoch_4-global_batch_100-acc_NotImplementedError"]
|
||||
assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths
|
||||
exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"]
|
||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完;
|
||||
else:
|
||||
assert "epoch_4-global_batch_52-acc_NotImplementedError" in all_saved_model_paths
|
||||
exception_model_path = all_saved_model_paths["epoch_4-global_batch_52-acc_NotImplementedError"]
|
||||
assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths
|
||||
exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"]
|
||||
|
||||
assert len(all_saved_model_paths) == 1
|
||||
all_state_dicts = [exception_model_path]
|
||||
@ -332,12 +338,11 @@ def test_trainer_checkpoint_callback_1(
|
||||
|
||||
if version == 0:
|
||||
callbacks = [
|
||||
CheckpointCallback(
|
||||
TrainerCheckpointCallback(
|
||||
monitor="acc",
|
||||
is_trainer_checkpoint=True,
|
||||
save_folder=path,
|
||||
save_every_n_epochs=7,
|
||||
save_every_n_global_batches=123, # 避免和 epoch 的保存重复;
|
||||
save_every_n_batches=123, # 避免和 epoch 的保存重复;
|
||||
save_topk=None,
|
||||
save_last=False,
|
||||
save_on_exception=None,
|
||||
@ -346,12 +351,11 @@ def test_trainer_checkpoint_callback_1(
|
||||
]
|
||||
elif version == 1:
|
||||
callbacks = [
|
||||
CheckpointCallback(
|
||||
TrainerCheckpointCallback(
|
||||
monitor="acc",
|
||||
is_trainer_checkpoint=True,
|
||||
save_folder=path,
|
||||
save_every_n_epochs=None,
|
||||
save_every_n_global_batches=None,
|
||||
save_every_n_batches=None,
|
||||
save_topk=2,
|
||||
save_last=True,
|
||||
save_on_exception=None,
|
||||
@ -383,31 +387,31 @@ def test_trainer_checkpoint_callback_1(
|
||||
if version == 0:
|
||||
|
||||
if driver == "torch":
|
||||
assert "epoch_7-global_batch_175-acc" in all_saved_model_paths
|
||||
assert "epoch_4-global_batch_123-acc" in all_saved_model_paths
|
||||
assert "trainer-epoch_7" in all_saved_model_paths
|
||||
assert "trainer-epoch_4-batch_123" in all_saved_model_paths
|
||||
|
||||
epoch_save_path = all_saved_model_paths["epoch_7-global_batch_175-acc"]
|
||||
step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"]
|
||||
epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
|
||||
step_save_path = all_saved_model_paths["trainer-epoch_4-batch_123"]
|
||||
|
||||
assert len(all_saved_model_paths) == 3
|
||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完;
|
||||
else:
|
||||
assert "epoch_7-global_batch_91-acc" in all_saved_model_paths
|
||||
assert "epoch_9-global_batch_123-acc" in all_saved_model_paths
|
||||
assert "trainer-epoch_7" in all_saved_model_paths
|
||||
assert "trainer-epoch_9-batch_123" in all_saved_model_paths
|
||||
|
||||
epoch_save_path = all_saved_model_paths["epoch_7-global_batch_91-acc"]
|
||||
step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"]
|
||||
epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
|
||||
step_save_path = all_saved_model_paths["trainer-epoch_9-batch_123"]
|
||||
|
||||
assert len(all_saved_model_paths) == 2
|
||||
all_state_dicts = [epoch_save_path, step_save_path]
|
||||
|
||||
elif version == 1:
|
||||
|
||||
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
|
||||
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
|
||||
|
||||
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
|
||||
if driver == "torch":
|
||||
assert "last" in all_saved_model_paths
|
||||
assert "trainer-last" in all_saved_model_paths
|
||||
aLL_topk_folders = []
|
||||
for each_folder_name in all_saved_model_paths:
|
||||
each_folder_name = pattern.findall(each_folder_name)
|
||||
@ -415,13 +419,13 @@ def test_trainer_checkpoint_callback_1(
|
||||
aLL_topk_folders.append(each_folder_name[0])
|
||||
assert len(aLL_topk_folders) == 2
|
||||
|
||||
last_save_path = all_saved_model_paths["last"]
|
||||
last_save_path = all_saved_model_paths["trainer-last"]
|
||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
|
||||
|
||||
assert len(all_saved_model_paths) == 3
|
||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完;
|
||||
else:
|
||||
assert "last" in all_saved_model_paths
|
||||
assert "trainer-last" in all_saved_model_paths
|
||||
|
||||
aLL_topk_folders = []
|
||||
for each_folder_name in all_saved_model_paths:
|
||||
@ -430,7 +434,7 @@ def test_trainer_checkpoint_callback_1(
|
||||
aLL_topk_folders.append(each_folder_name[0])
|
||||
assert len(aLL_topk_folders) == 2
|
||||
|
||||
last_save_path = all_saved_model_paths["last"]
|
||||
last_save_path = all_saved_model_paths["trainer-last"]
|
||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
|
||||
|
||||
assert len(all_saved_model_paths) == 3
|
||||
@ -474,10 +478,11 @@ def test_trainer_checkpoint_callback_2(
|
||||
device,
|
||||
version
|
||||
):
|
||||
pytest.skip("Skip transformers test for now.")
|
||||
path = Path.cwd().joinpath(f"test_model_checkpoint")
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
import transformers
|
||||
import transformers # 版本4.16.2
|
||||
import torch
|
||||
from torchmetrics import Accuracy
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
@ -587,12 +592,11 @@ def test_trainer_checkpoint_callback_2(
|
||||
|
||||
if version == 0:
|
||||
callbacks = [
|
||||
CheckpointCallback(
|
||||
TrainerCheckpointCallback(
|
||||
monitor="acc",
|
||||
is_trainer_checkpoint=True,
|
||||
save_folder=path,
|
||||
save_every_n_epochs=None,
|
||||
save_every_n_global_batches=50,
|
||||
save_every_n_batches=50,
|
||||
save_topk=None,
|
||||
save_last=False,
|
||||
save_on_exception=None,
|
||||
@ -601,12 +605,11 @@ def test_trainer_checkpoint_callback_2(
|
||||
]
|
||||
elif version == 1:
|
||||
callbacks = [
|
||||
CheckpointCallback(
|
||||
TrainerCheckpointCallback(
|
||||
monitor="acc",
|
||||
is_trainer_checkpoint=True,
|
||||
save_folder=path,
|
||||
save_every_n_epochs=None,
|
||||
save_every_n_global_batches=None,
|
||||
save_every_n_batches=None,
|
||||
save_topk=1,
|
||||
save_last=True,
|
||||
save_on_exception=None,
|
||||
@ -638,27 +641,27 @@ def test_trainer_checkpoint_callback_2(
|
||||
if version == 0:
|
||||
|
||||
if driver == "torch":
|
||||
assert "epoch_1-global_batch_200-acc" in all_saved_model_paths
|
||||
assert "trainer-epoch_1-batch_200" in all_saved_model_paths
|
||||
|
||||
epoch_save_path = all_saved_model_paths["epoch_1-global_batch_200-acc"]
|
||||
epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"]
|
||||
|
||||
assert len(all_saved_model_paths) == 4
|
||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完;
|
||||
else:
|
||||
assert "epoch_1-global_batch_100-acc" in all_saved_model_paths
|
||||
assert "trainer-epoch_1-batch_100" in all_saved_model_paths
|
||||
|
||||
epoch_save_path = all_saved_model_paths["epoch_1-global_batch_100-acc"]
|
||||
epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_100"]
|
||||
|
||||
assert len(all_saved_model_paths) == 2
|
||||
all_state_dicts = [epoch_save_path]
|
||||
|
||||
elif version == 1:
|
||||
|
||||
pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*")
|
||||
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
|
||||
|
||||
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
|
||||
if driver == "torch":
|
||||
assert "last" in all_saved_model_paths
|
||||
assert "trainer-last" in all_saved_model_paths
|
||||
aLL_topk_folders = []
|
||||
for each_folder_name in all_saved_model_paths:
|
||||
each_folder_name = pattern.findall(each_folder_name)
|
||||
@ -666,13 +669,13 @@ def test_trainer_checkpoint_callback_2(
|
||||
aLL_topk_folders.append(each_folder_name[0])
|
||||
assert len(aLL_topk_folders) == 1
|
||||
|
||||
last_save_path = all_saved_model_paths["last"]
|
||||
last_save_path = all_saved_model_paths["trainer-last"]
|
||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
|
||||
|
||||
assert len(all_saved_model_paths) == 2
|
||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完;
|
||||
else:
|
||||
assert "last" in all_saved_model_paths
|
||||
assert "trainer-last" in all_saved_model_paths
|
||||
|
||||
aLL_topk_folders = []
|
||||
for each_folder_name in all_saved_model_paths:
|
||||
@ -681,7 +684,7 @@ def test_trainer_checkpoint_callback_2(
|
||||
aLL_topk_folders.append(each_folder_name[0])
|
||||
assert len(aLL_topk_folders) == 1
|
||||
|
||||
last_save_path = all_saved_model_paths["last"]
|
||||
last_save_path = all_saved_model_paths["trainer-last"]
|
||||
topk_save_path = all_saved_model_paths[aLL_topk_folders[0]]
|
||||
|
||||
assert len(all_saved_model_paths) == 2
|
||||
|
@ -1,4 +1,4 @@
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader
|
||||
from fastNLP.core.dataset import DataSet
|
||||
@ -17,7 +17,7 @@ class RandomDataset(Dataset):
|
||||
return 10
|
||||
|
||||
|
||||
class TestPaddle(unittest.TestCase):
|
||||
class TestPaddle:
|
||||
|
||||
def test_init(self):
|
||||
# ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})
|
||||
|
@ -1,25 +1,25 @@
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from fastNLP.core.dataloaders.torch_dataloader import FDataLoader, prepare_dataloader
|
||||
from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.io.data_bundle import DataBundle
|
||||
|
||||
|
||||
class TestFdl(unittest.TestCase):
|
||||
class TestFdl:
|
||||
|
||||
def test_init_v1(self):
|
||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
|
||||
fdl = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True)
|
||||
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True)
|
||||
# for batch in fdl:
|
||||
# print(batch)
|
||||
fdl1 = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True)
|
||||
fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True)
|
||||
# for batch in fdl1:
|
||||
# print(batch)
|
||||
|
||||
def test_set_padding(self):
|
||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
|
||||
ds.set_pad_val("x", val=-1)
|
||||
fdl = FDataLoader(ds, batch_size=3)
|
||||
fdl = TorchDataLoader(ds, batch_size=3)
|
||||
fdl.set_input("x", "y")
|
||||
for batch in fdl:
|
||||
print(batch)
|
||||
@ -36,7 +36,7 @@ class TestFdl(unittest.TestCase):
|
||||
_dict["Y"].append(ins['y'])
|
||||
return _dict
|
||||
|
||||
fdl = FDataLoader(ds, batch_size=3, as_numpy=True)
|
||||
fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True)
|
||||
fdl.set_input("x", "y")
|
||||
fdl.add_collator(collate_fn)
|
||||
for batch in fdl:
|
||||
@ -44,7 +44,7 @@ class TestFdl(unittest.TestCase):
|
||||
|
||||
def test_get_batch_indices(self):
|
||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
|
||||
fdl = FDataLoader(ds, batch_size=3, shuffle=True)
|
||||
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True)
|
||||
fdl.set_input("y", "x")
|
||||
for batch in fdl:
|
||||
print(fdl.get_batch_indices())
|
||||
@ -67,30 +67,30 @@ class TestFdl(unittest.TestCase):
|
||||
return object.__getattribute__(self, item)
|
||||
|
||||
dataset = _DataSet()
|
||||
dl = FDataLoader(dataset, batch_size=2, shuffle=True)
|
||||
dl = TorchDataLoader(dataset, batch_size=2, shuffle=True)
|
||||
# dl.set_inputs('data', 'labels')
|
||||
# dl.set_pad_val('labels', val=None)
|
||||
for batch in dl:
|
||||
print(batch)
|
||||
print(dl.get_batch_indices())
|
||||
|
||||
def test_prepare_dataloader(self):
|
||||
def test_prepare_torch_dataloader(self):
|
||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
|
||||
dl = prepare_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
|
||||
assert isinstance(dl, FDataLoader)
|
||||
dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True, num_workers=2)
|
||||
assert isinstance(dl, TorchDataLoader)
|
||||
|
||||
ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
|
||||
dbl = DataBundle(datasets={'train': ds, 'val': ds1})
|
||||
dl_bundle = prepare_dataloader(dbl)
|
||||
assert isinstance(dl_bundle['train'], FDataLoader)
|
||||
assert isinstance(dl_bundle['val'], FDataLoader)
|
||||
dl_bundle = prepare_torch_dataloader(dbl)
|
||||
assert isinstance(dl_bundle['train'], TorchDataLoader)
|
||||
assert isinstance(dl_bundle['val'], TorchDataLoader)
|
||||
|
||||
ds_dict = {'train_1': ds, 'val': ds1}
|
||||
dl_dict = prepare_dataloader(ds_dict)
|
||||
assert isinstance(dl_dict['train_1'], FDataLoader)
|
||||
assert isinstance(dl_dict['val'], FDataLoader)
|
||||
dl_dict = prepare_torch_dataloader(ds_dict)
|
||||
assert isinstance(dl_dict['train_1'], TorchDataLoader)
|
||||
assert isinstance(dl_dict['val'], TorchDataLoader)
|
||||
|
||||
sequence = [ds, ds1]
|
||||
seq_ds = prepare_dataloader(sequence)
|
||||
assert isinstance(seq_ds[0], FDataLoader)
|
||||
assert isinstance(seq_ds[1], FDataLoader)
|
||||
seq_ds = prepare_torch_dataloader(sequence)
|
||||
assert isinstance(seq_ds[0], TorchDataLoader)
|
||||
assert isinstance(seq_ds[1], TorchDataLoader)
|
||||
|
@ -1,12 +1,12 @@
|
||||
import os
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException
|
||||
|
||||
|
||||
class TestDataSetInit(unittest.TestCase):
|
||||
class TestDataSetInit:
|
||||
"""初始化DataSet的办法有以下几种:
|
||||
1) 用dict:
|
||||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
|
||||
@ -24,46 +24,46 @@ class TestDataSetInit(unittest.TestCase):
|
||||
def test_init_v1(self):
|
||||
# 一维list
|
||||
ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40)
|
||||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
|
||||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
|
||||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
|
||||
assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
|
||||
assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
|
||||
assert ds.field_arrays["y"].content == [[5, 6], ] * 40
|
||||
|
||||
def test_init_v2(self):
|
||||
# 用dict
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
|
||||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
|
||||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
|
||||
assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
|
||||
assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
|
||||
assert ds.field_arrays["y"].content == [[5, 6], ] * 40
|
||||
|
||||
def test_init_assert(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
with pytest.raises(AssertionError):
|
||||
_ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100})
|
||||
with self.assertRaises(AssertionError):
|
||||
with pytest.raises(AssertionError):
|
||||
_ = DataSet([[1, 2, 3, 4]] * 10)
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
_ = DataSet(0.00001)
|
||||
|
||||
|
||||
class TestDataSetMethods(unittest.TestCase):
|
||||
class TestDataSetMethods:
|
||||
def test_append(self):
|
||||
dd = DataSet()
|
||||
for _ in range(3):
|
||||
dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
|
||||
self.assertEqual(len(dd), 3)
|
||||
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
|
||||
self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)
|
||||
assert len(dd) == 3
|
||||
assert dd.field_arrays["x"].content == [[1, 2, 3, 4]] * 3
|
||||
assert dd.field_arrays["y"].content == [[5, 6]] * 3
|
||||
|
||||
def test_add_field(self):
|
||||
dd = DataSet()
|
||||
dd.add_field("x", [[1, 2, 3]] * 10)
|
||||
dd.add_field("y", [[1, 2, 3, 4]] * 10)
|
||||
dd.add_field("z", [[5, 6]] * 10)
|
||||
self.assertEqual(len(dd), 10)
|
||||
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10)
|
||||
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
|
||||
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)
|
||||
assert len(dd) == 10
|
||||
assert dd.field_arrays["x"].content == [[1, 2, 3]] * 10
|
||||
assert dd.field_arrays["y"].content == [[1, 2, 3, 4]] * 10
|
||||
assert dd.field_arrays["z"].content == [[5, 6]] * 10
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
with pytest.raises(RuntimeError):
|
||||
dd.add_field("??", [[1, 2]] * 40)
|
||||
|
||||
def test_delete_field(self):
|
||||
@ -71,8 +71,8 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
dd.add_field("x", [[1, 2, 3]] * 10)
|
||||
dd.add_field("y", [[1, 2, 3, 4]] * 10)
|
||||
dd.delete_field("x")
|
||||
self.assertFalse("x" in dd.field_arrays)
|
||||
self.assertTrue("y" in dd.field_arrays)
|
||||
assert ("x" in dd.field_arrays) == False
|
||||
assert "y" in dd.field_arrays
|
||||
|
||||
def test_delete_instance(self):
|
||||
dd = DataSet()
|
||||
@ -80,99 +80,113 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
dd.add_field("x", [[1, 2, 3]] * old_length)
|
||||
dd.add_field("y", [[1, 2, 3, 4]] * old_length)
|
||||
dd.delete_instance(0)
|
||||
self.assertEqual(len(dd), old_length - 1)
|
||||
assert len(dd) == old_length - 1
|
||||
dd.delete_instance(0)
|
||||
self.assertEqual(len(dd), old_length - 2)
|
||||
assert len(dd) == old_length - 2
|
||||
|
||||
def test_getitem(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
ins_1, ins_0 = ds[0], ds[1]
|
||||
self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance))
|
||||
self.assertEqual(ins_1["x"], [1, 2, 3, 4])
|
||||
self.assertEqual(ins_1["y"], [5, 6])
|
||||
self.assertEqual(ins_0["x"], [1, 2, 3, 4])
|
||||
self.assertEqual(ins_0["y"], [5, 6])
|
||||
assert isinstance(ins_1, Instance) and isinstance(ins_0, Instance) == True
|
||||
assert ins_1["x"] == [1, 2, 3, 4]
|
||||
assert ins_1["y"] == [5, 6]
|
||||
assert ins_0["x"] == [1, 2, 3, 4]
|
||||
assert ins_0["y"] == [5, 6]
|
||||
|
||||
sub_ds = ds[:10]
|
||||
self.assertTrue(isinstance(sub_ds, DataSet))
|
||||
self.assertEqual(len(sub_ds), 10)
|
||||
assert isinstance(sub_ds, DataSet) == True
|
||||
assert len(sub_ds) == 10
|
||||
|
||||
sub_ds_1 = ds[[10, 0, 2, 3]]
|
||||
self.assertTrue(isinstance(sub_ds_1, DataSet))
|
||||
self.assertEqual(len(sub_ds_1), 4)
|
||||
assert isinstance(sub_ds_1, DataSet) == True
|
||||
assert len(sub_ds_1) == 4
|
||||
|
||||
field_array = ds['x']
|
||||
self.assertTrue(isinstance(field_array, FieldArray))
|
||||
self.assertEqual(len(field_array), 40)
|
||||
assert isinstance(field_array, FieldArray) == True
|
||||
assert len(field_array) == 40
|
||||
|
||||
def test_setitem(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
ds.add_field('i', list(range(len(ds))))
|
||||
assert ds.get_field('i').content == list(range(len(ds)))
|
||||
import random
|
||||
random.shuffle(ds)
|
||||
import numpy as np
|
||||
np.random.shuffle(ds)
|
||||
assert ds.get_field('i').content != list(range(len(ds)))
|
||||
|
||||
ins1 = ds[1]
|
||||
ds[2] = ds[1]
|
||||
assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y']
|
||||
|
||||
def test_get_item_error(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
with pytest.raises(RuntimeError):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
|
||||
_ = ds[40:]
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
with pytest.raises(KeyError):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
|
||||
_ = ds["kom"]
|
||||
|
||||
def test_len_(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
self.assertEqual(len(ds), 40)
|
||||
assert len(ds) == 40
|
||||
|
||||
ds = DataSet()
|
||||
self.assertEqual(len(ds), 0)
|
||||
assert len(ds) == 0
|
||||
|
||||
def test_add_fieldarray(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*40))
|
||||
self.assertEqual(ds['z'].content, [[7, 8]]*40)
|
||||
ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 40))
|
||||
assert ds['z'].content == [[7, 8]] * 40
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*10))
|
||||
with pytest.raises(RuntimeError):
|
||||
ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 10))
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
with pytest.raises(TypeError):
|
||||
ds.add_fieldarray('z', [1, 2, 4])
|
||||
|
||||
def test_copy_field(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
ds.copy_field('x', 'z')
|
||||
self.assertEqual(ds['x'].content, ds['z'].content)
|
||||
assert ds['x'].content == ds['z'].content
|
||||
|
||||
def test_has_field(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
self.assertTrue(ds.has_field('x'))
|
||||
self.assertFalse(ds.has_field('z'))
|
||||
assert ds.has_field('x') == True
|
||||
assert ds.has_field('z') == False
|
||||
|
||||
def test_get_field(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
with self.assertRaises(KeyError):
|
||||
with pytest.raises(KeyError):
|
||||
ds.get_field('z')
|
||||
x_array = ds.get_field('x')
|
||||
self.assertEqual(x_array.content, [[1, 2, 3, 4]] * 40)
|
||||
assert x_array.content == [[1, 2, 3, 4]] * 40
|
||||
|
||||
def test_get_all_fields(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
field_arrays = ds.get_all_fields()
|
||||
self.assertEqual(field_arrays["x"], [[1, 2, 3, 4]] * 40)
|
||||
self.assertEqual(field_arrays['y'], [[5, 6]] * 40)
|
||||
assert field_arrays["x"].content == [[1, 2, 3, 4]] * 40
|
||||
assert field_arrays['y'].content == [[5, 6]] * 40
|
||||
|
||||
def test_get_field_names(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
field_names = ds.get_field_names()
|
||||
self.assertTrue('x' in field_names)
|
||||
self.assertTrue('y' in field_names)
|
||||
assert 'x' in field_names
|
||||
assert 'y' in field_names
|
||||
|
||||
def test_apply(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000})
|
||||
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx')
|
||||
self.assertTrue("rx" in ds.field_arrays)
|
||||
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])
|
||||
assert ("rx" in ds.field_arrays) == True
|
||||
assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1]
|
||||
|
||||
ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
|
||||
self.assertEqual(ds.field_arrays["y"].content[0], 2)
|
||||
assert ds.field_arrays["y"].content[0] == 2
|
||||
|
||||
res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
|
||||
self.assertTrue(isinstance(res, list) and len(res) > 0)
|
||||
self.assertTrue(res[0], 4)
|
||||
assert (isinstance(res, list) and len(res) > 0) == True
|
||||
assert res[0] == 4
|
||||
|
||||
ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k")
|
||||
# expect no exception raised
|
||||
@ -192,6 +206,7 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
|
||||
def modify_inplace(instance):
|
||||
instance['words'] = 1
|
||||
|
||||
ds.apply(modify_inplace)
|
||||
# with self.assertRaises(TypeError):
|
||||
# ds.apply(modify_inplace)
|
||||
@ -216,48 +231,48 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
|
||||
T.apply_more(func_1)
|
||||
# print(T['c'][0, 1, 2])
|
||||
self.assertEqual(list(T["c"].content), [2, 4, 6])
|
||||
self.assertEqual(list(T["d"].content), [1, 4, 9])
|
||||
assert list(T["c"].content) == [2, 4, 6]
|
||||
assert list(T["d"].content) == [1, 4, 9]
|
||||
|
||||
res = T.apply_field_more(func_2, "a", modify_fields=False)
|
||||
self.assertEqual(list(T["c"].content), [2, 4, 6])
|
||||
self.assertEqual(list(T["d"].content), [1, 4, 9])
|
||||
self.assertEqual(list(res["c"]), [3, 6, 9])
|
||||
self.assertEqual(list(res["d"]), [1, 8, 27])
|
||||
assert list(T["c"].content) == [2, 4, 6]
|
||||
assert list(T["d"].content) == [1, 4, 9]
|
||||
assert list(res["c"]) == [3, 6, 9]
|
||||
assert list(res["d"]) == [1, 8, 27]
|
||||
|
||||
with self.assertRaises(ApplyResultException) as e:
|
||||
with pytest.raises(ApplyResultException) as e:
|
||||
T.apply_more(func_err_1)
|
||||
print(e)
|
||||
|
||||
with self.assertRaises(ApplyResultException) as e:
|
||||
with pytest.raises(ApplyResultException) as e:
|
||||
T.apply_field_more(func_err_2, "a")
|
||||
print(e)
|
||||
|
||||
def test_drop(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
|
||||
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)
|
||||
self.assertEqual(len(ds), 20)
|
||||
assert len(ds) == 20
|
||||
|
||||
def test_contains(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
self.assertTrue("x" in ds)
|
||||
self.assertTrue("y" in ds)
|
||||
self.assertFalse("z" in ds)
|
||||
assert ("x" in ds) == True
|
||||
assert ("y" in ds) == True
|
||||
assert ("z" in ds) == False
|
||||
|
||||
def test_rename_field(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
|
||||
ds.rename_field("x", "xx")
|
||||
self.assertTrue("xx" in ds)
|
||||
self.assertFalse("x" in ds)
|
||||
assert ("xx" in ds) == True
|
||||
assert ("x" in ds) == False
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
with pytest.raises(KeyError):
|
||||
ds.rename_field("yyy", "oo")
|
||||
|
||||
def test_split(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
|
||||
d1, d2 = ds.split(0.1)
|
||||
self.assertEqual(len(d1), len(ds)*0.9)
|
||||
self.assertEqual(len(d2), len(ds)*0.1)
|
||||
assert len(d2) == (len(ds) * 0.9)
|
||||
assert len(d1) == (len(ds) * 0.1)
|
||||
|
||||
def test_add_field_v2(self):
|
||||
ds = DataSet({"x": [3, 4]})
|
||||
@ -268,14 +283,14 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
def test_save_load(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
|
||||
ds.save("./my_ds.pkl")
|
||||
self.assertTrue(os.path.exists("./my_ds.pkl"))
|
||||
assert os.path.exists("./my_ds.pkl") == True
|
||||
|
||||
ds_1 = DataSet.load("./my_ds.pkl")
|
||||
os.remove("my_ds.pkl")
|
||||
|
||||
def test_add_null(self):
|
||||
ds = DataSet()
|
||||
with self.assertRaises(RuntimeError) as RE:
|
||||
with pytest.raises(RuntimeError) as RE:
|
||||
ds.add_field('test', [])
|
||||
|
||||
def test_concat(self):
|
||||
@ -287,16 +302,16 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]})
|
||||
ds3 = ds1.concat(ds2)
|
||||
|
||||
self.assertEqual(len(ds3), 20)
|
||||
assert len(ds3) == 20
|
||||
|
||||
self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4])
|
||||
self.assertListEqual(ds1[10]['x'], [4, 3, 2, 1])
|
||||
assert ds1[9]['x'] == [1, 2, 3, 4]
|
||||
assert ds1[10]['x'] == [4, 3, 2, 1]
|
||||
|
||||
ds2[0]['x'][0] = 100
|
||||
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了
|
||||
assert ds3[10]['x'][0] == 4 # 不改变copy后的field了
|
||||
|
||||
ds3[10]['x'][0] = -100
|
||||
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了
|
||||
assert ds2[0]['x'][0] == 100 # 不改变copy前的field了
|
||||
|
||||
# 测试inplace
|
||||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
|
||||
@ -304,19 +319,19 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
ds3 = ds1.concat(ds2, inplace=True)
|
||||
|
||||
ds2[0]['x'][0] = 100
|
||||
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了
|
||||
assert ds3[10]['x'][0] == 4 # 不改变copy后的field了
|
||||
|
||||
ds3[10]['x'][0] = -100
|
||||
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了
|
||||
assert ds2[0]['x'][0] == 100 # 不改变copy前的field了
|
||||
|
||||
ds3[0]['x'][0] = 100
|
||||
self.assertEqual(ds1[0]['x'][0], 100) # 改变copy前的field了
|
||||
assert ds1[0]['x'][0] == 100 # 改变copy前的field了
|
||||
|
||||
# 测试mapping
|
||||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
|
||||
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]})
|
||||
ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
|
||||
self.assertEqual(len(ds3), 20)
|
||||
assert len(ds3) == 20
|
||||
|
||||
# 测试忽略掉多余的
|
||||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
|
||||
@ -326,7 +341,7 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
# 测试报错
|
||||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
|
||||
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]})
|
||||
with self.assertRaises(RuntimeError):
|
||||
with pytest.raises(RuntimeError):
|
||||
ds3 = ds1.concat(ds2, field_mapping={'X': 'x'})
|
||||
|
||||
def test_instance_field_disappear_bug(self):
|
||||
@ -334,7 +349,7 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
data.copy_field(field_name='raw_chars', new_field_name='chars')
|
||||
_data = data[:1]
|
||||
for field_name in ['raw_chars', 'target', 'chars']:
|
||||
self.assertTrue(_data.has_field(field_name))
|
||||
assert _data.has_field(field_name) == True
|
||||
|
||||
def test_from_pandas(self):
|
||||
import pandas as pd
|
||||
@ -342,8 +357,8 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
|
||||
ds = DataSet.from_pandas(df)
|
||||
print(ds)
|
||||
self.assertEqual(ds['x'].content, [1, 2, 3])
|
||||
self.assertEqual(ds['y'].content, [4, 5, 6])
|
||||
assert ds['x'].content == [1, 2, 3]
|
||||
assert ds['y'].content == [4, 5, 6]
|
||||
|
||||
def test_to_pandas(self):
|
||||
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
|
||||
@ -352,7 +367,7 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
def test_to_csv(self):
|
||||
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
|
||||
ds.to_csv("1.csv")
|
||||
self.assertTrue(os.path.exists("1.csv"))
|
||||
assert os.path.exists("1.csv") == True
|
||||
os.remove("1.csv")
|
||||
|
||||
def test_add_collate_fn(self):
|
||||
@ -360,27 +375,26 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
|
||||
def collate_fn(item):
|
||||
return item
|
||||
ds.add_collate_fn(collate_fn)
|
||||
|
||||
self.assertEqual(len(ds.collate_fns.collators), 2)
|
||||
ds.add_collate_fn(collate_fn)
|
||||
|
||||
def test_get_collator(self):
|
||||
from typing import Callable
|
||||
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
|
||||
collate_fn = ds.get_collator()
|
||||
self.assertEqual(isinstance(collate_fn, Callable), True)
|
||||
assert isinstance(collate_fn, Callable) == True
|
||||
|
||||
def test_add_seq_len(self):
|
||||
ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]})
|
||||
ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
|
||||
ds.add_seq_len('x')
|
||||
print(ds)
|
||||
|
||||
def test_set_target(self):
|
||||
ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]})
|
||||
ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
|
||||
ds.set_target('x')
|
||||
|
||||
|
||||
class TestFieldArrayInit(unittest.TestCase):
|
||||
class TestFieldArrayInit:
|
||||
"""
|
||||
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray:
|
||||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
|
||||
@ -428,7 +442,6 @@ class TestFieldArrayInit(unittest.TestCase):
|
||||
# list of array
|
||||
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])])
|
||||
|
||||
|
||||
def test_init_v8(self):
|
||||
# 二维list
|
||||
val = np.array([[1, 2], [3, 4]])
|
||||
@ -436,78 +449,78 @@ class TestFieldArrayInit(unittest.TestCase):
|
||||
fa.append(val)
|
||||
|
||||
|
||||
class TestFieldArray(unittest.TestCase):
|
||||
class TestFieldArray:
|
||||
def test_main(self):
|
||||
fa = FieldArray("x", [1, 2, 3, 4, 5])
|
||||
self.assertEqual(len(fa), 5)
|
||||
assert len(fa) == 5
|
||||
fa.append(6)
|
||||
self.assertEqual(len(fa), 6)
|
||||
assert len(fa) == 6
|
||||
|
||||
self.assertEqual(fa[-1], 6)
|
||||
self.assertEqual(fa[0], 1)
|
||||
assert fa[-1] == 6
|
||||
assert fa[0] == 1
|
||||
fa[-1] = 60
|
||||
self.assertEqual(fa[-1], 60)
|
||||
assert fa[-1] == 60
|
||||
|
||||
self.assertEqual(fa.get(0), 1)
|
||||
self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray))
|
||||
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3])
|
||||
assert fa.get(0) == 1
|
||||
assert isinstance(fa.get([0, 1, 2]), np.ndarray) == True
|
||||
assert list(fa.get([0, 1, 2])) == [1, 2, 3]
|
||||
|
||||
def test_getitem_v1(self):
|
||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
|
||||
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
|
||||
assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]
|
||||
ans = fa[[0, 1]]
|
||||
self.assertTrue(isinstance(ans, np.ndarray))
|
||||
self.assertTrue(isinstance(ans[0], np.ndarray))
|
||||
self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5])
|
||||
self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5])
|
||||
self.assertEqual(ans.dtype, np.float64)
|
||||
assert isinstance(ans, np.ndarray) == True
|
||||
assert isinstance(ans[0], np.ndarray) == True
|
||||
assert ans[0].tolist() == [1.1, 2.2, 3.3, 4.4, 5.5]
|
||||
assert ans[1].tolist() == [1, 2, 3, 4, 5]
|
||||
assert ans.dtype == np.float64
|
||||
|
||||
def test_getitem_v2(self):
|
||||
x = np.random.rand(10, 5)
|
||||
fa = FieldArray("my_field", x)
|
||||
indices = [0, 1, 3, 4, 6]
|
||||
for a, b in zip(fa[indices], x[indices]):
|
||||
self.assertListEqual(a.tolist(), b.tolist())
|
||||
assert a.tolist() == b.tolist()
|
||||
|
||||
def test_append(self):
|
||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
|
||||
fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
|
||||
self.assertEqual(len(fa), 3)
|
||||
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])
|
||||
assert len(fa) == 3
|
||||
assert fa[2] == [1.2, 2.3, 3.4, 4.5, 5.6]
|
||||
|
||||
def test_pop(self):
|
||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
|
||||
fa.pop(0)
|
||||
self.assertEqual(len(fa), 1)
|
||||
self.assertEqual(fa[0], [1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
assert len(fa) == 1
|
||||
assert fa[0] == [1.0, 2.0, 3.0, 4.0, 5.0]
|
||||
fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5]
|
||||
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
|
||||
assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]
|
||||
|
||||
|
||||
class TestCase(unittest.TestCase):
|
||||
class TestCase:
|
||||
|
||||
def test_init(self):
|
||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
|
||||
ins = Instance(x=[1, 2, 3], y=[4, 5, 6])
|
||||
self.assertTrue(isinstance(ins.fields, dict))
|
||||
self.assertEqual(ins.fields, fields)
|
||||
assert isinstance(ins.fields, dict) == True
|
||||
assert ins.fields == fields
|
||||
|
||||
ins = Instance(**fields)
|
||||
self.assertEqual(ins.fields, fields)
|
||||
assert ins.fields == fields
|
||||
|
||||
def test_add_field(self):
|
||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
|
||||
ins = Instance(**fields)
|
||||
ins.add_field("z", [1, 1, 1])
|
||||
fields.update({"z": [1, 1, 1]})
|
||||
self.assertEqual(ins.fields, fields)
|
||||
assert ins.fields == fields
|
||||
|
||||
def test_get_item(self):
|
||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
|
||||
ins = Instance(**fields)
|
||||
self.assertEqual(ins["x"], [1, 2, 3])
|
||||
self.assertEqual(ins["y"], [4, 5, 6])
|
||||
self.assertEqual(ins["z"], [1, 1, 1])
|
||||
assert ins["x"] == [1, 2, 3]
|
||||
assert ins["y"] == [4, 5, 6]
|
||||
assert ins["z"] == [1, 1, 1]
|
||||
|
||||
def test_repr(self):
|
||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
|
||||
|
@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
|
||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from fastNLP.core import synchronize_safe_rm
|
||||
|
@ -30,7 +30,7 @@ class SequenceDataSet:
|
||||
|
||||
|
||||
def check_replace_sampler(driver):
|
||||
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler
|
||||
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler
|
||||
# reproducible 是 True 和 False
|
||||
|
||||
# 需要 check 返回的 sampler 和 dataloader 都不同了
|
||||
|
@ -118,7 +118,6 @@ class TestAccuracy:
|
||||
def test_v1(self, is_ddp: bool, dataset: DataSet, metric_class: Type['Metric'],
|
||||
metric_kwargs: Dict[str, Any]) -> None:
|
||||
global pool
|
||||
print(pool)
|
||||
if is_ddp:
|
||||
if sys.platform == "win32":
|
||||
pytest.skip("DDP not supported on windows")
|
||||
|
@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
import unittest
|
||||
|
||||
from collections import Counter
|
||||
import os, sys
|
||||
import copy
|
||||
@ -14,6 +14,7 @@ from torch.multiprocessing import Pool, set_start_method
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.core.metrics import SpanFPreRecMetric
|
||||
from fastNLP.core.dataset import DataSet
|
||||
|
||||
set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
@ -45,7 +46,6 @@ def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
|
||||
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
print(torch.cuda.device_count())
|
||||
if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
|
||||
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
|
||||
|
||||
@ -64,15 +64,15 @@ def find_free_network_port() -> int:
|
||||
return port
|
||||
|
||||
|
||||
@pytest.fixture(scope='class', autouse=True)
|
||||
def pre_process():
|
||||
global pool
|
||||
pool = Pool(processes=NUM_PROCESSES)
|
||||
master_port = find_free_network_port()
|
||||
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
|
||||
yield
|
||||
pool.close()
|
||||
pool.join()
|
||||
# @pytest.fixture(scope='class', autouse=True)
|
||||
# def pre_process():
|
||||
# global pool
|
||||
# pool = Pool(processes=NUM_PROCESSES)
|
||||
# master_port = find_free_network_port()
|
||||
# pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
|
||||
# yield
|
||||
# pool.close()
|
||||
# pool.join()
|
||||
|
||||
|
||||
def _test(local_rank: int,
|
||||
@ -87,18 +87,19 @@ def _test(local_rank: int,
|
||||
# dataset 也类似(每个进程有自己的一个)
|
||||
dataset = copy.deepcopy(dataset)
|
||||
metric.to(device)
|
||||
print(os.environ.get("MASTER_PORT", "xx"))
|
||||
# 把数据拆到每个 GPU 上,有点模仿 DistributedSampler 的感觉,但这里数据单位是一个 batch(即每个 i 取了一个 batch 到自己的 GPU 上)
|
||||
for i in range(local_rank, len(dataset), world_size):
|
||||
pred, tg, seq_len = dataset[i]['pred'].to(device), dataset[i]['tg'].to(device), dataset[i]['seq_len']
|
||||
print(tg, seq_len)
|
||||
metric.update(pred, tg, seq_len)
|
||||
|
||||
my_result = metric.get_metric()
|
||||
print(my_result)
|
||||
print(sklearn_metric)
|
||||
assert my_result == sklearn_metric
|
||||
|
||||
|
||||
class SpanFPreRecMetricTest(unittest.TestCase):
|
||||
global pool
|
||||
class TestSpanFPreRecMetric:
|
||||
|
||||
def test_case1(self):
|
||||
from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans
|
||||
@ -135,38 +136,36 @@ class SpanFPreRecMetricTest(unittest.TestCase):
|
||||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
|
||||
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels))
|
||||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
|
||||
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
|
||||
-0.3782, 0.8240],
|
||||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
|
||||
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
|
||||
-0.3782, 0.8240],
|
||||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
|
||||
-0.3562, -1.4116],
|
||||
[1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858,
|
||||
2.0023, 0.7075],
|
||||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186,
|
||||
[ 1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858,
|
||||
2.0023, 0.7075],
|
||||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186,
|
||||
0.3832, -0.1540],
|
||||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120,
|
||||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120,
|
||||
-1.3508, -0.9513],
|
||||
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
|
||||
-0.0842, -0.4294]],
|
||||
[ 1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
|
||||
-0.0842, -0.4294]],
|
||||
|
||||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
|
||||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
|
||||
-1.4138, -0.8853],
|
||||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
|
||||
-1.0726, 0.0364],
|
||||
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
|
||||
-0.8836, -0.9320],
|
||||
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
|
||||
-1.6857, 1.1571],
|
||||
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
|
||||
-0.5837, 1.0184],
|
||||
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
|
||||
-0.9025, 0.0864]]])
|
||||
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4],
|
||||
[4, 1, 7, 0, 4, 7]])
|
||||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
|
||||
-1.0726, 0.0364],
|
||||
[ 0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
|
||||
-0.8836, -0.9320],
|
||||
[ 0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
|
||||
-1.6857, 1.1571],
|
||||
[ 1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
|
||||
-0.5837, 1.0184],
|
||||
[ 1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
|
||||
-0.9025, 0.0864]]])
|
||||
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]])
|
||||
fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6])
|
||||
expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5,
|
||||
'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0,
|
||||
'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2}
|
||||
|
||||
assert expect_bio_res == fastnlp_bio_metric.get_metric()
|
||||
# print(fastnlp_bio_metric.get_metric())
|
||||
|
||||
@ -253,7 +252,7 @@ class SpanFPreRecMetricTest(unittest.TestCase):
|
||||
# print(expected_metric)
|
||||
metric_value = metric.get_metric()
|
||||
for key, value in expected_metric.items():
|
||||
self.assertAlmostEqual(value, metric_value[key], places=5)
|
||||
np.allclose(value, metric_value[key])
|
||||
|
||||
def test_auto_encoding_type_infer(self):
|
||||
# 检查是否可以自动check encode的类型
|
||||
@ -270,9 +269,8 @@ class SpanFPreRecMetricTest(unittest.TestCase):
|
||||
vocab.add_word('o')
|
||||
vocabs[encoding_type] = vocab
|
||||
for e in ['bio', 'bioes', 'bmeso']:
|
||||
with self.subTest(e=e):
|
||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e])
|
||||
assert metric.encoding_type == e
|
||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e])
|
||||
assert metric.encoding_type == e
|
||||
|
||||
bmes_vocab = _generate_tags('bmes')
|
||||
vocab = Vocabulary()
|
||||
@ -285,7 +283,7 @@ class SpanFPreRecMetricTest(unittest.TestCase):
|
||||
vocab = Vocabulary()
|
||||
for i in range(10):
|
||||
vocab.add_word(str(i))
|
||||
with self.assertRaises(Exception):
|
||||
with pytest.raises(Exception):
|
||||
metric = SpanFPreRecMetric(vocab)
|
||||
|
||||
def test_encoding_type(self):
|
||||
@ -304,65 +302,72 @@ class SpanFPreRecMetricTest(unittest.TestCase):
|
||||
vocab.add_word('o')
|
||||
vocabs[encoding_type] = vocab
|
||||
for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']):
|
||||
with self.subTest(e1=e1, e2=e2):
|
||||
if e1 == e2:
|
||||
if e1 == e2:
|
||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2)
|
||||
else:
|
||||
s2 = set(e2)
|
||||
s2.update(set(e1))
|
||||
if s2 == set(e2):
|
||||
continue
|
||||
with pytest.raises(AssertionError):
|
||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2)
|
||||
else:
|
||||
s2 = set(e2)
|
||||
s2.update(set(e1))
|
||||
if s2 == set(e2):
|
||||
continue
|
||||
with self.assertRaises(AssertionError):
|
||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2)
|
||||
for encoding_type in ['bio', 'bioes', 'bmeso']:
|
||||
with self.assertRaises(AssertionError):
|
||||
with pytest.raises(AssertionError):
|
||||
metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes')
|
||||
|
||||
with self.assertWarns(Warning):
|
||||
with pytest.warns(Warning):
|
||||
vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes'))
|
||||
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso')
|
||||
vocab = Vocabulary().add_word_lst(list('bmes'))
|
||||
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso')
|
||||
|
||||
def test_case5(self):
|
||||
global pool
|
||||
# pool = Pool(NUM_PROCESSES)
|
||||
# master_port = find_free_network_port()
|
||||
# pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
|
||||
# global pool
|
||||
pool = Pool(NUM_PROCESSES)
|
||||
master_port = find_free_network_port()
|
||||
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)])
|
||||
number_labels = 4
|
||||
# bio tag
|
||||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
|
||||
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels))
|
||||
# fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
|
||||
dataset = DataSet({'pred': [torch.FloatTensor(
|
||||
[[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
|
||||
-0.3782, 0.8240],
|
||||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
|
||||
-0.3562, -1.4116],
|
||||
[1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858,
|
||||
2.0023, 0.7075],
|
||||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186,
|
||||
0.3832, -0.1540],
|
||||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120,
|
||||
-1.3508, -0.9513],
|
||||
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
|
||||
-0.0842, -0.4294]],
|
||||
dataset = DataSet({'pred': [
|
||||
torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
|
||||
-0.3782, 0.8240],
|
||||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
|
||||
-0.3562, -1.4116],
|
||||
[1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858,
|
||||
2.0023, 0.7075],
|
||||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186,
|
||||
0.3832, -0.1540],
|
||||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120,
|
||||
-1.3508, -0.9513],
|
||||
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
|
||||
-0.0842, -0.4294]]
|
||||
|
||||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
|
||||
-1.4138, -0.8853],
|
||||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
|
||||
-1.0726, 0.0364],
|
||||
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
|
||||
-0.8836, -0.9320],
|
||||
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
|
||||
-1.6857, 1.1571],
|
||||
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
|
||||
-0.5837, 1.0184],
|
||||
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
|
||||
-0.9025, 0.0864]]])] * 100,
|
||||
'tg': [torch.LongTensor([[3, 6, 0, 8, 2, 4],
|
||||
[4, 1, 7, 0, 4, 7]])] * 100,
|
||||
'seq_len': [[6, 6]] * 100})
|
||||
]),
|
||||
torch.FloatTensor([
|
||||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
|
||||
-1.4138, -0.8853],
|
||||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
|
||||
-1.0726, 0.0364],
|
||||
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
|
||||
-0.8836, -0.9320],
|
||||
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
|
||||
-1.6857, 1.1571],
|
||||
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
|
||||
-0.5837, 1.0184],
|
||||
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
|
||||
-0.9025, 0.0864]]
|
||||
])
|
||||
],
|
||||
'tg': [
|
||||
torch.LongTensor([[3, 6, 0, 8, 2, 4]]),
|
||||
torch.LongTensor([[4, 1, 7, 0, 4, 7]])
|
||||
],
|
||||
'seq_len': [
|
||||
[6], [6]
|
||||
]})
|
||||
metric_kwargs = {
|
||||
'tag_vocab': fastnlp_bio_vocab,
|
||||
'only_gross': False,
|
||||
@ -372,7 +377,6 @@ class SpanFPreRecMetricTest(unittest.TestCase):
|
||||
'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0,
|
||||
'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2}
|
||||
processes = NUM_PROCESSES
|
||||
print(torch.cuda.device_count())
|
||||
|
||||
pool.starmap(
|
||||
partial(
|
||||
@ -384,3 +388,5 @@ class SpanFPreRecMetricTest(unittest.TestCase):
|
||||
),
|
||||
[(rank, processes, torch.device(f'cuda:{rank}')) for rank in range(processes)]
|
||||
)
|
||||
pool.close()
|
||||
pool.join()
|
@ -4,7 +4,7 @@ import numpy as np
|
||||
import pytest
|
||||
from itertools import chain
|
||||
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler
|
||||
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
|
||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
|
||||
@ -18,7 +18,7 @@ class TestReproducibleBatchSampler:
|
||||
before_batch_size = 7
|
||||
dataset = TorchNormalDataset(num_of_data=100)
|
||||
dataloader = DataLoader(dataset, batch_size=before_batch_size)
|
||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||
|
||||
forward_steps = 3
|
||||
@ -28,15 +28,15 @@ class TestReproducibleBatchSampler:
|
||||
|
||||
# 1. 保存状态
|
||||
_get_re_batchsampler = dataloader.batch_sampler
|
||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
|
||||
assert isinstance(_get_re_batchsampler, RandomBatchSampler)
|
||||
state = _get_re_batchsampler.state_dict()
|
||||
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size,
|
||||
"sampler_type": "ReproducibleBatchSampler"}
|
||||
"sampler_type": "RandomBatchSampler"}
|
||||
|
||||
# 2. 断点重训,重新生成一个 dataloader;
|
||||
# 不改变 batch_size;
|
||||
dataloader = DataLoader(dataset, batch_size=before_batch_size)
|
||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
re_batchsampler.load_state_dict(state)
|
||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||
|
||||
@ -53,7 +53,7 @@ class TestReproducibleBatchSampler:
|
||||
# 改变 batch_size;
|
||||
after_batch_size = 3
|
||||
dataloader = DataLoader(dataset, batch_size=after_batch_size)
|
||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
re_batchsampler.load_state_dict(state)
|
||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||
|
||||
@ -99,7 +99,7 @@ class TestReproducibleBatchSampler:
|
||||
dataset = TorchNormalDataset(num_of_data=100)
|
||||
# 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的;
|
||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
|
||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||
|
||||
# 将一轮的所有数据保存下来,看是否恢复的是正确的;
|
||||
@ -111,13 +111,13 @@ class TestReproducibleBatchSampler:
|
||||
|
||||
# 1. 保存状态
|
||||
_get_re_batchsampler = dataloader.batch_sampler
|
||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
|
||||
assert isinstance(_get_re_batchsampler, RandomBatchSampler)
|
||||
state = _get_re_batchsampler.state_dict()
|
||||
|
||||
# 2. 断点重训,重新生成一个 dataloader;
|
||||
# 不改变 batch_size;
|
||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
|
||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
|
||||
re_batchsampler.load_state_dict(state)
|
||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
|
||||
|
||||
@ -416,7 +416,6 @@ class TestBucketedBatchSampler:
|
||||
@pytest.mark.parametrize('num_replica', [2, 3])
|
||||
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica):
|
||||
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2):
|
||||
# TODO 两个 rank 上的长度是要在同一个bucket的
|
||||
dataset = DatasetWithVaryLength(num_of_data=num_samples)
|
||||
batch_size = 6
|
||||
if num_replica*batch_size > num_samples:
|
||||
|
@ -1,18 +1,14 @@
|
||||
import unittest
|
||||
|
||||
from itertools import product
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from functools import partial
|
||||
from array import array
|
||||
from itertools import chain
|
||||
|
||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler
|
||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
|
||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
|
||||
|
||||
|
||||
class TestRandomSamplerYh(unittest.TestCase):
|
||||
class TestRandomSamplerYh:
|
||||
def test_init(self):
|
||||
# 测试能否正确初始化
|
||||
dataset = TorchNormalDataset(num_of_data=100)
|
||||
@ -24,7 +20,7 @@ class TestRandomSamplerYh(unittest.TestCase):
|
||||
dataset = TorchNormalDataset(num_of_data=100)
|
||||
sampler = RandomSampler(dataset)
|
||||
for i in sampler:
|
||||
with self.assertRaises(AssertionError):
|
||||
with pytest.raises(AssertionError):
|
||||
sampler.set_distributed(1, 0)
|
||||
break
|
||||
|
||||
@ -37,39 +33,39 @@ class TestRandomSamplerYh(unittest.TestCase):
|
||||
dataset = TorchNormalDataset(num_of_data=100)
|
||||
sampler = RandomSampler(dataset, shuffle=False)
|
||||
sampler.set_distributed(num_replicas=2, rank=0, pad=False)
|
||||
self.assertEqual(len(sampler), 50)
|
||||
assert len(sampler)==50
|
||||
count = 0
|
||||
for i in sampler:
|
||||
self.assertEqual(i%2, 0)
|
||||
assert i%2==0
|
||||
count += 1
|
||||
self.assertEqual(count, 50)
|
||||
assert count == 50
|
||||
|
||||
sampler.set_distributed(num_replicas=2, rank=1, pad=False)
|
||||
self.assertEqual(len(sampler), 50)
|
||||
assert len(sampler)==50
|
||||
count = 0
|
||||
for i in sampler:
|
||||
self.assertEqual(i%2, 1)
|
||||
assert i%2==1
|
||||
count += 1
|
||||
self.assertEqual(count, 50)
|
||||
assert count==50
|
||||
|
||||
dataset = TorchNormalDataset(num_of_data=101)
|
||||
sampler = RandomSampler(dataset, shuffle=False)
|
||||
sampler.set_distributed(num_replicas=2, rank=0, pad=True)
|
||||
self.assertEqual(len(sampler), 51)
|
||||
assert len(sampler)==51
|
||||
count = 0
|
||||
for i in sampler:
|
||||
self.assertEqual(i%2, 0)
|
||||
assert i%2==0
|
||||
count += 1
|
||||
self.assertEqual(count, 51)
|
||||
assert count == 51
|
||||
|
||||
sampler.set_distributed(num_replicas=2, rank=1, pad=True)
|
||||
self.assertEqual(len(sampler), 51)
|
||||
assert len(sampler) == 51
|
||||
count = 0
|
||||
for i in sampler:
|
||||
if i!=0:
|
||||
self.assertEqual(i%2, 1)
|
||||
assert i%2==1
|
||||
count += 1
|
||||
self.assertEqual(count, 51)
|
||||
assert count == 51
|
||||
|
||||
def test_state_dict_check_length(self):
|
||||
dataset = TorchNormalDataset(num_of_data=100)
|
||||
@ -77,7 +73,7 @@ class TestRandomSamplerYh(unittest.TestCase):
|
||||
states = sampler.state_dict()
|
||||
|
||||
new_ds = TorchNormalDataset(num_of_data=10)
|
||||
with self.assertRaises(AssertionError):
|
||||
with pytest.raises(AssertionError):
|
||||
new_sampler = RandomSampler(new_ds)
|
||||
new_sampler.load_state_dict(states)
|
||||
|
||||
@ -85,99 +81,107 @@ class TestRandomSamplerYh(unittest.TestCase):
|
||||
new_sampler = RandomSampler(new_ds)
|
||||
new_sampler.load_state_dict(states)
|
||||
|
||||
def test_state_dict(self):
|
||||
@pytest.mark.parametrize('pad', [True, False])
|
||||
@pytest.mark.parametrize('pre_shuffle', [True, False])
|
||||
@pytest.mark.parametrize('post_shuffle', [True, False])
|
||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
|
||||
def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
|
||||
num_samples = 100
|
||||
dataset = TorchNormalDataset(num_of_data=num_samples)
|
||||
# 测试使用 前后shuffle不一致的load操作
|
||||
lst = [0]+np.random.randint(1, num_samples, size=3).tolist()
|
||||
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False],
|
||||
lst):
|
||||
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples):
|
||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle)
|
||||
sampler.set_epoch(0)
|
||||
already_numbers = set()
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
self.assertEqual(len(already_numbers), num_consumed_samples)
|
||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle)
|
||||
sampler.set_epoch(0)
|
||||
already_numbers = set()
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
assert len(already_numbers) == num_consumed_samples
|
||||
|
||||
states = sampler.state_dict()
|
||||
states = sampler.state_dict()
|
||||
|
||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
for i in new_sampler:
|
||||
self.assertNotIn(i, already_numbers)
|
||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
for i in new_sampler:
|
||||
assert i not in already_numbers
|
||||
|
||||
# 测试切换成多卡也没有问题
|
||||
other_rank_number = set()
|
||||
for rank in range(3):
|
||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False)
|
||||
new_sampler.set_epoch(0)
|
||||
count = 0
|
||||
for i in new_sampler:
|
||||
self.assertNotIn(i, other_rank_number)
|
||||
other_rank_number.add(i)
|
||||
self.assertNotIn(i, already_numbers)
|
||||
count += 1
|
||||
# 测试切换成多卡也没有问题
|
||||
other_rank_number = set()
|
||||
for rank in range(3):
|
||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
|
||||
new_sampler.set_epoch(0)
|
||||
count = 0
|
||||
seen = 0
|
||||
seen_in_other_rank = 0
|
||||
for i in new_sampler:
|
||||
seen_in_other_rank += int(i in other_rank_number)
|
||||
other_rank_number.add(i)
|
||||
seen += int(i in already_numbers)
|
||||
count += 1
|
||||
assert seen <= 1 if pad else seen == 0
|
||||
assert seen_in_other_rank<=1 # 因为pad可能重复
|
||||
|
||||
def test_state_dict_2(self):
|
||||
@pytest.mark.parametrize('pad', [True, False])
|
||||
@pytest.mark.parametrize('pre_shuffle', [True, False])
|
||||
@pytest.mark.parametrize('post_shuffle', [True, False])
|
||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
|
||||
def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
|
||||
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
|
||||
num_samples = 100
|
||||
dataset = TorchNormalDataset(num_of_data=num_samples)
|
||||
# 测试使用 前后shuffle不一致的load操作
|
||||
lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist()
|
||||
# lst = [30]
|
||||
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False],
|
||||
lst):
|
||||
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples):
|
||||
already_numbers = set()
|
||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
|
||||
sampler.set_distributed(num_replicas=2, rank=0)
|
||||
sampler.set_epoch(0)
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
|
||||
sampler.set_epoch(0)
|
||||
sampler.set_distributed(num_replicas=2, rank=1)
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
self.assertEqual(len(already_numbers), num_consumed_samples*2)
|
||||
already_numbers = set()
|
||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
|
||||
sampler.set_distributed(num_replicas=2, rank=0)
|
||||
sampler.set_epoch(0)
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
|
||||
sampler.set_epoch(0)
|
||||
sampler.set_distributed(num_replicas=2, rank=1)
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
assert len(already_numbers) == num_consumed_samples*2
|
||||
|
||||
states = sampler.state_dict()
|
||||
states = sampler.state_dict()
|
||||
|
||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
for i in new_sampler:
|
||||
self.assertNotIn(i, already_numbers)
|
||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
for i in new_sampler:
|
||||
assert i not in already_numbers
|
||||
|
||||
# 测试切换成多卡也没有问题
|
||||
other_rank_number = set()
|
||||
for rank in range(3):
|
||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False)
|
||||
count = 0
|
||||
for i in new_sampler:
|
||||
self.assertNotIn(i, other_rank_number)
|
||||
other_rank_number.add(i)
|
||||
self.assertNotIn(i, already_numbers)
|
||||
count += 1
|
||||
# 测试切换成多卡也没有问题
|
||||
other_rank_number = set()
|
||||
for rank in range(3):
|
||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
|
||||
count = 0
|
||||
seen = 0
|
||||
seen_in_other_rank = 0
|
||||
for i in new_sampler:
|
||||
seen_in_other_rank += int(i in other_rank_number)
|
||||
other_rank_number.add(i)
|
||||
seen += int(i in already_numbers)
|
||||
count += 1
|
||||
assert seen <= 1 if pad else seen == 0
|
||||
assert seen_in_other_rank<=1 # 因为pad可能重复
|
||||
|
||||
|
||||
class TestRandomSampler(unittest.TestCase):
|
||||
class TestRandomSampler:
|
||||
# 测试单卡;
|
||||
def test_seed_work_when_shuffle_is_true(self):
|
||||
data_length = 100
|
||||
@ -360,4 +364,324 @@ class TestRandomSampler(unittest.TestCase):
|
||||
...
|
||||
|
||||
|
||||
class DatasetWithVaryLength:
|
||||
def __init__(self, num_of_data=100, reverse=False):
|
||||
self.data = np.arange(num_of_data)
|
||||
if reverse:
|
||||
self.data = self.data[::-1]
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.data[item]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class TestSortedSampler:
|
||||
def test_single(self):
|
||||
num_of_data = 100
|
||||
data = DatasetWithVaryLength(num_of_data)
|
||||
sampler = SortedSampler(data, length=data.data)
|
||||
indexes = list(sampler)
|
||||
assert indexes==list(range(num_of_data-1, -1, -1))
|
||||
|
||||
@pytest.mark.parametrize('pad', [True, False])
|
||||
@pytest.mark.parametrize('num_replica', [2, 3])
|
||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
|
||||
def test_multi(self, pad, num_replica, num_of_data):
|
||||
data = DatasetWithVaryLength(num_of_data=num_of_data)
|
||||
samplers = []
|
||||
for i in range(num_replica):
|
||||
sampler = SortedSampler(dataset=data, length=data.data)
|
||||
sampler.set_distributed(num_replica, rank=i, pad=pad)
|
||||
samplers.append(sampler)
|
||||
|
||||
# 保证顺序是没乱的
|
||||
already_seen_index = set()
|
||||
for sampler in samplers:
|
||||
larger_count = 0 # 这里为 0 就可以,因为最后补充的index一定是比较大的数。
|
||||
prev_index = float('inf')
|
||||
cur_set = set()
|
||||
seen_in_other_rank = 0
|
||||
for index in sampler:
|
||||
seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉
|
||||
cur_set.add(index)
|
||||
larger_count += int(index <= prev_index)
|
||||
prev_index = index
|
||||
assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序
|
||||
assert seen_in_other_rank <= 1 if pad else seen_in_other_rank == 0
|
||||
already_seen_index.update(cur_set)
|
||||
|
||||
indexes = list(chain(*samplers))
|
||||
indexes = set(indexes)
|
||||
if pad:
|
||||
assert indexes == set(range(num_of_data))
|
||||
else:
|
||||
assert len(indexes) <= num_of_data
|
||||
|
||||
@pytest.mark.parametrize('pad', [True, False])
|
||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
|
||||
def test_state_dict(self, pad, num_consumed_samples):
|
||||
num_samples = 100
|
||||
dataset = DatasetWithVaryLength(num_of_data=num_samples)
|
||||
# 测试使用 前后shuffle不一致的load操作
|
||||
sampler = SortedSampler(dataset, length=dataset.data)
|
||||
sampler.set_epoch(0)
|
||||
already_numbers = set()
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
if already_numbers:
|
||||
assert j<max(already_numbers)
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
assert len(already_numbers) == num_consumed_samples
|
||||
|
||||
states = sampler.state_dict()
|
||||
|
||||
new_sampler = SortedSampler(dataset, length=dataset.data)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
for i in new_sampler:
|
||||
if already_numbers:
|
||||
assert i < max(already_numbers)
|
||||
assert i not in already_numbers
|
||||
|
||||
# 测试切换成多卡也没有问题
|
||||
other_rank_number = set()
|
||||
for rank in range(3):
|
||||
new_sampler = SortedSampler(dataset, length=dataset.data)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
|
||||
new_sampler.set_epoch(0)
|
||||
count = 0
|
||||
seen = 0
|
||||
seen_in_other_rank = 0
|
||||
smaller = 0
|
||||
for i in new_sampler:
|
||||
if already_numbers:
|
||||
smaller += int(i >= max(already_numbers))
|
||||
seen_in_other_rank += int(i in other_rank_number)
|
||||
other_rank_number.add(i)
|
||||
seen += int(i in already_numbers)
|
||||
count += 1
|
||||
assert seen <= 1 if pad else seen == 0
|
||||
assert seen_in_other_rank<=1 # 因为pad可能重复
|
||||
assert smaller<=1 if pad else smaller==0
|
||||
|
||||
@pytest.mark.parametrize('pad', [True, False])
|
||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
|
||||
def test_state_dict_2(self, pad, num_consumed_samples):
|
||||
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
|
||||
num_samples = 100
|
||||
dataset = DatasetWithVaryLength(num_of_data=num_samples)
|
||||
# 测试使用 前后shuffle不一致的load操作
|
||||
# lst = [30]
|
||||
already_numbers = set()
|
||||
sampler = SortedSampler(dataset, length=dataset.data)
|
||||
sampler.set_distributed(num_replicas=2, rank=0)
|
||||
sampler.set_epoch(0)
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
if already_numbers:
|
||||
assert j<=max(already_numbers)
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
sampler = SortedSampler(dataset, length=dataset.data)
|
||||
sampler.set_epoch(0)
|
||||
sampler.set_distributed(num_replicas=2, rank=1)
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
assert len(already_numbers) == num_consumed_samples*2
|
||||
|
||||
states = sampler.state_dict()
|
||||
|
||||
new_sampler = SortedSampler(dataset, length=dataset.data)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
for i in new_sampler:
|
||||
if already_numbers:
|
||||
assert i < max(already_numbers)
|
||||
assert i not in already_numbers
|
||||
|
||||
# 测试切换成多卡也没有问题
|
||||
other_rank_number = set()
|
||||
for rank in range(3):
|
||||
new_sampler = SortedSampler(dataset, length=dataset.data)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
|
||||
count = 0
|
||||
seen = 0
|
||||
seen_in_other_rank = 0
|
||||
smaller = 0
|
||||
for i in new_sampler:
|
||||
if already_numbers:
|
||||
smaller += int(i>=max(already_numbers))
|
||||
seen_in_other_rank += int(i in other_rank_number)
|
||||
other_rank_number.add(i)
|
||||
seen += int(i in already_numbers)
|
||||
count += 1
|
||||
assert seen <= 1 if pad else seen == 0
|
||||
assert seen_in_other_rank<=1 # 因为pad可能重复
|
||||
assert smaller <= 1 if pad else smaller == 0
|
||||
|
||||
|
||||
class TestSequentialSampler:
|
||||
def test_single(self):
|
||||
num_of_data = 100
|
||||
data = DatasetWithVaryLength(num_of_data)
|
||||
sampler = SequentialSampler(data)
|
||||
indexes = list(sampler)
|
||||
assert indexes==list(range(num_of_data))
|
||||
|
||||
@pytest.mark.parametrize('pad', [True, False])
|
||||
@pytest.mark.parametrize('num_replica', [2, 3])
|
||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
|
||||
def test_multi(self, pad, num_replica, num_of_data):
|
||||
data = DatasetWithVaryLength(num_of_data=num_of_data)
|
||||
samplers = []
|
||||
for i in range(num_replica):
|
||||
sampler = SequentialSampler(dataset=data)
|
||||
sampler.set_distributed(num_replica, rank=i, pad=pad)
|
||||
samplers.append(sampler)
|
||||
|
||||
# 保证顺序是没乱的
|
||||
already_seen_index = set()
|
||||
for idx, sampler in enumerate(samplers):
|
||||
larger_count = 1
|
||||
prev_index = float('inf')
|
||||
cur_set = set()
|
||||
seen_in_other_rank = 0
|
||||
for index in sampler:
|
||||
seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉
|
||||
cur_set.add(index)
|
||||
larger_count += int(index >= prev_index)
|
||||
prev_index = index
|
||||
assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序
|
||||
assert seen_in_other_rank <= idx if pad else seen_in_other_rank == 0
|
||||
already_seen_index.update(cur_set)
|
||||
|
||||
indexes = list(chain(*samplers))
|
||||
indexes = set(indexes)
|
||||
if pad:
|
||||
assert indexes == set(range(num_of_data))
|
||||
else:
|
||||
assert len(indexes) <= num_of_data
|
||||
|
||||
@pytest.mark.parametrize('pad', [True, False])
|
||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
|
||||
def test_state_dict(self, pad, num_consumed_samples):
|
||||
num_samples = 100
|
||||
dataset = DatasetWithVaryLength(num_of_data=num_samples)
|
||||
# 测试使用 前后shuffle不一致的load操作
|
||||
sampler = SequentialSampler(dataset=dataset)
|
||||
sampler.set_epoch(0)
|
||||
already_numbers = set()
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
if already_numbers:
|
||||
assert j>max(already_numbers)
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
assert len(already_numbers) == num_consumed_samples
|
||||
|
||||
states = sampler.state_dict()
|
||||
|
||||
new_sampler = SequentialSampler(dataset=dataset)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
for i in new_sampler:
|
||||
if already_numbers:
|
||||
assert i > max(already_numbers)
|
||||
assert i not in already_numbers
|
||||
|
||||
# 测试切换成多卡也没有问题
|
||||
other_rank_number = set()
|
||||
for rank in range(3):
|
||||
new_sampler = SequentialSampler(dataset=dataset)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
|
||||
new_sampler.set_epoch(0)
|
||||
count = 0
|
||||
seen = 0
|
||||
seen_in_other_rank = 0
|
||||
smaller = 0
|
||||
for i in new_sampler:
|
||||
if already_numbers:
|
||||
smaller += int(i <= max(already_numbers))
|
||||
seen_in_other_rank += int(i in other_rank_number)
|
||||
other_rank_number.add(i)
|
||||
seen += int(i in already_numbers)
|
||||
count += 1
|
||||
assert seen <= 1 if pad else seen == 0
|
||||
assert seen_in_other_rank<=rank # 因为pad可能重复
|
||||
assert smaller<=1 if pad else smaller==0
|
||||
|
||||
@pytest.mark.parametrize('pad', [True, False])
|
||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
|
||||
def test_state_dict_2(self, pad, num_consumed_samples):
|
||||
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
|
||||
num_samples = 100
|
||||
dataset = DatasetWithVaryLength(num_of_data=num_samples)
|
||||
# 测试使用 前后shuffle不一致的load操作
|
||||
# lst = [30]
|
||||
already_numbers = set()
|
||||
sampler = SequentialSampler(dataset=dataset)
|
||||
sampler.set_distributed(num_replicas=2, rank=0)
|
||||
sampler.set_epoch(0)
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
if already_numbers:
|
||||
assert j>max(already_numbers)
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
sampler = SequentialSampler(dataset=dataset)
|
||||
sampler.set_epoch(0)
|
||||
sampler.set_distributed(num_replicas=2, rank=1)
|
||||
if num_consumed_samples>0:
|
||||
for i, j in enumerate(sampler, start=1):
|
||||
already_numbers.add(j)
|
||||
if i == num_consumed_samples:
|
||||
break
|
||||
assert len(already_numbers) == num_consumed_samples*2
|
||||
|
||||
states = sampler.state_dict()
|
||||
|
||||
new_sampler = SequentialSampler(dataset=dataset)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
for i in new_sampler:
|
||||
if already_numbers:
|
||||
assert i > max(already_numbers)
|
||||
assert i not in already_numbers
|
||||
|
||||
# 测试切换成多卡也没有问题
|
||||
other_rank_number = set()
|
||||
for rank in range(3):
|
||||
new_sampler = SequentialSampler(dataset=dataset)
|
||||
new_sampler.load_state_dict(states)
|
||||
new_sampler.set_epoch(0)
|
||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
|
||||
count = 0
|
||||
seen = 0
|
||||
seen_in_other_rank = 0
|
||||
smaller = 0
|
||||
for i in new_sampler:
|
||||
if already_numbers:
|
||||
smaller += int(i<max(already_numbers))
|
||||
seen_in_other_rank += int(i in other_rank_number)
|
||||
other_rank_number.add(i)
|
||||
seen += int(i in already_numbers)
|
||||
count += 1
|
||||
assert seen <= 1 if pad else seen == 0
|
||||
assert seen_in_other_rank<=1 # 因为pad可能重复
|
||||
assert smaller <= rank if pad else smaller == 0
|
||||
|
||||
|
||||
|
104
tests/core/samplers/test_unrepeated_sampler.py
Normal file
104
tests/core/samplers/test_unrepeated_sampler.py
Normal file
@ -0,0 +1,104 @@
|
||||
from itertools import chain
|
||||
|
||||
import pytest
|
||||
|
||||
from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
|
||||
|
||||
|
||||
class DatasetWithVaryLength:
|
||||
def __init__(self, num_of_data=100):
|
||||
self.data = list(range(num_of_data))
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.data[item]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class TestUnrepeatedSampler:
|
||||
@pytest.mark.parametrize('shuffle', [True, False])
|
||||
def test_single(self, shuffle):
|
||||
num_of_data = 100
|
||||
data = DatasetWithVaryLength(num_of_data)
|
||||
sampler = UnrepeatedRandomSampler(data, shuffle)
|
||||
indexes = set(sampler)
|
||||
assert indexes==set(range(num_of_data))
|
||||
|
||||
@pytest.mark.parametrize('num_replica', [2, 3])
|
||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
|
||||
@pytest.mark.parametrize('shuffle', [False, True])
|
||||
def test_multi(self, num_replica, num_of_data, shuffle):
|
||||
data = DatasetWithVaryLength(num_of_data=num_of_data)
|
||||
samplers = []
|
||||
for i in range(num_replica):
|
||||
sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle)
|
||||
sampler.set_distributed(num_replica, rank=i)
|
||||
samplers.append(sampler)
|
||||
|
||||
indexes = list(chain(*samplers))
|
||||
assert len(indexes) == num_of_data
|
||||
indexes = set(indexes)
|
||||
assert indexes==set(range(num_of_data))
|
||||
|
||||
|
||||
class TestUnrepeatedSortedSampler:
|
||||
def test_single(self):
|
||||
num_of_data = 100
|
||||
data = DatasetWithVaryLength(num_of_data)
|
||||
sampler = UnrepeatedSortedSampler(data, length=data.data)
|
||||
indexes = list(sampler)
|
||||
assert indexes==list(range(num_of_data-1, -1, -1))
|
||||
|
||||
@pytest.mark.parametrize('num_replica', [2, 3])
|
||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
|
||||
def test_multi(self, num_replica, num_of_data):
|
||||
data = DatasetWithVaryLength(num_of_data=num_of_data)
|
||||
samplers = []
|
||||
for i in range(num_replica):
|
||||
sampler = UnrepeatedSortedSampler(dataset=data, length=data.data)
|
||||
sampler.set_distributed(num_replica, rank=i)
|
||||
samplers.append(sampler)
|
||||
|
||||
# 保证顺序是没乱的
|
||||
for sampler in samplers:
|
||||
prev_index = float('inf')
|
||||
for index in sampler:
|
||||
assert index <= prev_index
|
||||
prev_index = index
|
||||
|
||||
indexes = list(chain(*samplers))
|
||||
assert len(indexes) == num_of_data # 不同卡之间没有交叉
|
||||
indexes = set(indexes)
|
||||
assert indexes==set(range(num_of_data))
|
||||
|
||||
|
||||
class TestUnrepeatedSequentialSampler:
|
||||
def test_single(self):
|
||||
num_of_data = 100
|
||||
data = DatasetWithVaryLength(num_of_data)
|
||||
sampler = UnrepeatedSequentialSampler(data, length=data.data)
|
||||
indexes = list(sampler)
|
||||
assert indexes==list(range(num_of_data))
|
||||
|
||||
@pytest.mark.parametrize('num_replica', [2, 3])
|
||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
|
||||
def test_multi(self, num_replica, num_of_data):
|
||||
data = DatasetWithVaryLength(num_of_data=num_of_data)
|
||||
samplers = []
|
||||
for i in range(num_replica):
|
||||
sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data)
|
||||
sampler.set_distributed(num_replica, rank=i)
|
||||
samplers.append(sampler)
|
||||
|
||||
# 保证顺序是没乱的
|
||||
for sampler in samplers:
|
||||
prev_index = float('-inf')
|
||||
for index in sampler:
|
||||
assert index>=prev_index
|
||||
prev_index = index
|
||||
|
||||
indexes = list(chain(*samplers))
|
||||
assert len(indexes) == num_of_data
|
||||
indexes = set(indexes)
|
||||
assert indexes == set(range(num_of_data))
|
Loading…
Reference in New Issue
Block a user