fastNLP/tutorials/fastnlp_tutorial_5.ipynb
2022-06-04 21:15:44 +08:00


{
"cells": [
{
"cell_type": "markdown",
"id": "fdd7ff16",
"metadata": {},
"source": [
"# T5. trainer 和 evaluator 的深入介绍\n",
"\n",
"  1   fastNLP 中 driver 的补充介绍\n",
" \n",
"    1.1   trainer 和 driver 的构想 \n",
"\n",
"    1.2   device 与 多卡训练\n",
"\n",
"  2   fastNLP 中的更多 metric 类型\n",
"\n",
"    2.1   预定义的 metric 类型\n",
"\n",
"    2.2   自定义的 metric 类型\n",
"\n",
"  3   fastNLP 中 trainer 的补充介绍\n",
"\n",
"    3.1   trainer 的内部结构"
]
},
{
"cell_type": "markdown",
"id": "08752c5a",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## 1. fastNLP 中 driver 的补充介绍\n",
"\n",
"### 1.1 trainer 和 driver 的构想\n",
"\n",
"在`fastNLP 0.8`中,模型训练最关键的模块便是**训练模块`trainer`、评测模块`evaluator`、驱动模块`driver`**\n",
"\n",
"  在`tutorial 0`中,已经简单介绍过上述三个模块:**`driver`用来控制训练评测中的`model`的最终运行**\n",
"\n",
"    **`evaluator`封装评测的`metric`****`trainer`封装训练的`optimizer`****也可以包括`evaluator`**\n",
"\n",
"之所以做出上述的划分,其根本目的在于要**达成对于多个`python`学习框架****例如`pytorch`、`paddle`、`jittor`的兼容**\n",
"\n",
"  对于训练环节,其伪代码如下方左边紫色一栏所示,由于**不同框架对模型、损失、张量的定义各有不同**,所以将训练环节\n",
"\n",
"    划分为**框架无关的循环控制、批量分发部分****由`trainer`模块负责**实现,对应的伪代码如下方中间蓝色一栏所示\n",
"\n",
"    以及**随框架不同的模型调用、数值优化部分****由`driver`模块负责**实现,对应的伪代码如下方右边红色一栏所示\n",
"\n",
"| <div align=\"center\">训练过程</div> | <div align=\"center\">框架无关 对应`trainer`</div> | <div align=\"center\">框架相关 对应`driver`</div> |\n",
"|:--|:--|:--|\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">try:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">try:</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">for epoch in 1:n_eoochs:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">for epoch in 1:n_eoochs:</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">for step in 1:total_steps:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:40px;\">for step in 1:total_steps:</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">batch = fetch_batch()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:60px;\">batch = fetch_batch()</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">loss = model.forward(batch)&emsp;</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">loss = model.forward(batch)&emsp;</div> |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">loss.backward()</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">loss.backward()</div> |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">model.clear_grad()</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">model.clear_grad()</div> |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">model.update()</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">model.update()</div> |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">if need_save:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:40px;\">if need_save:</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">model.save()</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">model.save()</div> |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">except:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">except:</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">process_exception()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">process_exception()</div> | |"
]
},
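{
"cell_type": "markdown",
"id": "a7d3e101",
"metadata": {},
"source": [
"&emsp; The division of labor in the table above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not `fastNLP`'s implementation: `ToyDriver`, `run_training`, and the one-parameter model `y = w * x` are hypothetical names introduced here. `run_training` only controls the loop and dispatches batches, while every model-specific operation (`forward`, `backward`, `step`, `zero_grad`) is delegated to the driver\n",
"\n",
"```python\n",
"class ToyDriver:\n",
"    # Hypothetical driver: owns every framework-specific operation.\n",
"    # The 'framework' here is plain Python fitting y = w * x.\n",
"    def __init__(self, lr=0.1):\n",
"        self.w = 0.0     # the single model parameter\n",
"        self.grad = 0.0  # gradient of the loss with respect to w\n",
"        self.lr = lr\n",
"\n",
"    def forward(self, batch):\n",
"        # mean squared error over the batch; keep the batch for backward()\n",
"        self.batch = batch\n",
"        return sum((self.w * x - y) ** 2 for x, y in batch) / len(batch)\n",
"\n",
"    def backward(self):\n",
"        # d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x), added to grad\n",
"        n = len(self.batch)\n",
"        self.grad += sum(2 * (self.w * x - y) * x for x, y in self.batch) / n\n",
"\n",
"    def step(self):\n",
"        self.w -= self.lr * self.grad\n",
"\n",
"    def zero_grad(self):\n",
"        self.grad = 0.0\n",
"\n",
"\n",
"def run_training(driver, batches, n_epochs):\n",
"    # Hypothetical trainer: loop control and batch dispatch only.\n",
"    losses = []\n",
"    try:\n",
"        for _ in range(n_epochs):\n",
"            for batch in batches:\n",
"                losses.append(driver.forward(batch))\n",
"                driver.backward()\n",
"                driver.step()\n",
"                driver.zero_grad()\n",
"    except Exception as exc:\n",
"        print('training stopped:', exc)  # stands in for process_exception()\n",
"        raise\n",
"    return losses\n",
"\n",
"\n",
"batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]  # toy data, y = 2x\n",
"driver = ToyDriver(lr=0.02)\n",
"losses = run_training(driver, batches, n_epochs=50)\n",
"print(round(driver.w, 3))  # prints 2.0\n",
"```\n",
"\n",
"&emsp; Replacing `ToyDriver` with a driver for another framework would leave `run_training` untouched, which is exactly what the `trainer`/`driver` split aims for"
]
},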
{
"cell_type": "markdown",
"id": "3e55f07b",
"metadata": {},
"source": [
"&emsp; 对于评测环节,其伪代码如下方左边紫色一栏所示,同样由于不同框架对模型、损失、张量的定义各有不同,所以将评测环节\n",
"\n",
"&emsp; &emsp; 划分为**框架无关的循环控制、分发汇总部分****由`evaluator`模块负责**实现,对应的伪代码如下方中间蓝色一栏所示\n",
"\n",
"&emsp; &emsp; 以及**随框架不同的模型调用、评测计算部分**,同样**由`driver`模块负责**实现,对应的伪代码如下方右边红色一栏所示\n",
"\n",
"| <div align=\"center\">评测过程</div> | <div align=\"center\">框架无关 对应`evaluator`</div> | <div align=\"center\">框架相关 对应`driver`</div> |\n",
"|:--|:--|:--|\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">try:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">try:</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">model.set_eval()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">model.set_eval()</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">for step in 1:total_steps:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">for step in 1:total_steps:</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">batch = fetch_batch()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:40px;\">batch = fetch_batch()</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">outputs = model.evaluate(batch)&emsp;</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:40px;\">outputs = model.evaluate(batch)&emsp;</div> |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">metric.compute(batch, outputs)</div> | | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:40px;\">metric.compute(batch, outputs)</div> |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">results = metric.get_metric()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">results = metric.get_metric()</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">except:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">except:</div> | |\n",
"| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">process_exception()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">process_exception()</div> | |"
]
},
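{
"cell_type": "markdown",
"id": "a7d3e102",
"metadata": {},
"source": [
"&emsp; The evaluation side can be sketched the same way. Again this is a hypothetical illustration rather than `fastNLP` code: `ToyEvalDriver` stands in for the driver's model call, `ToyAccuracy` follows the `update`/`get_metric` protocol, and `run_evaluation` plays the evaluator, which only loops over batches and collects the result. For brevity the loop calls `metric.update` itself, whereas the table assigns the metric computation to the driver\n",
"\n",
"```python\n",
"class ToyEvalDriver:\n",
"    # Hypothetical driver: the 'model' predicts 1 when x exceeds a threshold.\n",
"    def __init__(self, threshold):\n",
"        self.threshold = threshold\n",
"\n",
"    def set_eval(self):\n",
"        pass  # a real framework would disable dropout and similar layers here\n",
"\n",
"    def evaluate(self, batch):\n",
"        return [1 if x > self.threshold else 0 for x, _ in batch]\n",
"\n",
"\n",
"class ToyAccuracy:\n",
"    # Hypothetical metric following the update / get_metric protocol.\n",
"    def __init__(self):\n",
"        self.right, self.total = 0, 0\n",
"\n",
"    def update(self, batch, outputs):\n",
"        self.right += sum(int(p == y) for p, (_, y) in zip(outputs, batch))\n",
"        self.total += len(batch)\n",
"\n",
"    def get_metric(self):\n",
"        return {'acc': self.right / self.total}\n",
"\n",
"\n",
"def run_evaluation(driver, metric, batches):\n",
"    # Hypothetical evaluator: loop control and result collection only.\n",
"    driver.set_eval()\n",
"    for batch in batches:\n",
"        outputs = driver.evaluate(batch)\n",
"        metric.update(batch, outputs)\n",
"    return metric.get_metric()\n",
"\n",
"\n",
"batches = [[(0.2, 0), (0.7, 1)], [(0.4, 1), (0.9, 1)]]  # (feature, label) pairs\n",
"result = run_evaluation(ToyEvalDriver(threshold=0.5), ToyAccuracy(), batches)\n",
"print(result)  # prints {'acc': 0.75}\n",
"```"
]
},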
{
"cell_type": "markdown",
"id": "94ba11c6",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"source": [
"由此,从程序员的角度,`fastNLP v0.8`**通过一个`driver`让基于`pytorch`、`paddle`、`jittor`框架的模型**\n",
"\n",
"&emsp; &emsp; **都能在相同的`trainer`和`evaluator`上运行**,这也**是`fastNLP v0.8`相比于之前版本的一大亮点**\n",
"\n",
"&emsp; 而从`driver`的角度,`fastNLP v0.8`通过定义一个`driver`基类,**将所有张量转化为`numpy.tensor`**\n",
"\n",
"&emsp; &emsp; 并由此泛化出`torch_driver`、`paddle_driver`、`jittor_driver`三个子类,从而实现了\n",
"\n",
"&emsp; &emsp; 对`pytorch`、`paddle`、`jittor`的兼容,有关后两者的实践请参考接下来的`tutorial-6`"
]
},
{
"cell_type": "markdown",
"id": "ab1cea7d",
"metadata": {},
"source": [
"### 1.2 device 与 多卡训练\n",
"\n",
"**`fastNLP v0.8`支持多卡训练**,实现方法则是**通过将`trainer`中的`device`设置为对应显卡的序号列表**\n",
"\n",
"&emsp; 由单卡切换成多卡,无论是数据、模型还是评测都会面临一定的调整,`fastNLP v0.8`保证:\n",
"\n",
"&emsp; &emsp; 数据拆分时,不同卡之间相互协调,所有数据都可以被训练,且不会使用到相同的数据\n",
"\n",
"&emsp; &emsp; 模型训练时,模型之间需要交换梯度;评测计算时,每张卡先各自计算,再汇总结果\n",
"\n",
"&emsp; 例如,在评测计算运行`get_metric`函数时,`fastNLP v0.8`将自动按照`self.right`和`self.total`\n",
"\n",
"&emsp; &emsp; 指定的**`aggregate_method`方法**,默认为`sum`,将每张卡上结果汇总起来,因此最终\n",
"\n",
"&emsp; &emsp; 在调用`get_metric`方法时,`Accuracy`类能够返回全部的统计结果,代码如下\n",
" \n",
"```python\n",
"trainer = Trainer(\n",
" model=model, # model 基于 pytorch 实现 \n",
" train_dataloader=train_dataloader,\n",
" optimizers=optimizer,\n",
" ...\n",
" driver='torch', # driver 使用 torch_driver \n",
" device=[0, 1], # gpu 选择 cuda:0 + cuda:1\n",
" ...\n",
" evaluate_dataloaders=evaluate_dataloader,\n",
" metrics={'acc': Accuracy()},\n",
" ...\n",
" )\n",
"\n",
"class Accuracy(Metric):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.register_element(name='total', value=0, aggregate_method='sum')\n",
" self.register_element(name='right', value=0, aggregate_method='sum')\n",
"```\n"
]
},
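{
"cell_type": "markdown",
"id": "a7d3e103",
"metadata": {},
"source": [
"&emsp; A small worked example shows why the per-card counts are combined with `sum` before dividing. The per-card numbers are made up for illustration, and plain dictionaries stand in for the registered elements: summing gives 75 / 90 ≈ 0.833, whereas averaging the two per-card accuracies (0.9 and 0.75) gives 0.825, because it gives the card that saw fewer samples the same weight as the other\n",
"\n",
"```python\n",
"# Hypothetical statistics collected on two GPUs after one evaluation pass.\n",
"per_card = [\n",
"    {'right': 45, 'total': 50},  # cuda:0\n",
"    {'right': 30, 'total': 40},  # cuda:1\n",
"]\n",
"\n",
"# aggregate_method='sum': add each element across cards, then compute accuracy\n",
"right = sum(card['right'] for card in per_card)\n",
"total = sum(card['total'] for card in per_card)\n",
"global_acc = right / total\n",
"\n",
"# for comparison: averaging per-card accuracies weights every card equally\n",
"mean_of_card_acc = sum(c['right'] / c['total'] for c in per_card) / len(per_card)\n",
"\n",
"print(right, total, round(global_acc, 3), round(mean_of_card_acc, 3))\n",
"```"
]
},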
{
"cell_type": "markdown",
"id": "e2e0a210",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"source": [
"注:`fastNLP v0.8`中要求`jupyter`不能多卡,仅能单卡,故在所有`tutorial`中均不作相关演示"
]
},
{
"cell_type": "markdown",
"id": "8d19220c",
"metadata": {},
"source": [
"## 2. fastNLP 中的更多 metric 类型\n",
"\n",
"### 2.1 预定义的 metric 类型\n",
"\n",
"在`fastNLP 0.8`中,除了前几篇`tutorial`中经常见到的**正确率`Accuracy`**,还有其他**预定义的评测标准`metric`**\n",
"\n",
"&emsp; 包括**所有`metric`的基类`Metric`**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n",
"\n",
"&emsp; &emsp; **适用于分类语境下的`F1`值`ClassifyFPreRecMetric`**(其中也包括召回率`Pre`、精确率`Rec`\n",
"\n",
"&emsp; &emsp; **适用于抽取语境下的`F1`值`SpanFPreRecMetric`**;相关基本信息内容见下表,之后是详细分析\n",
"\n",
"| <div align=\"center\">代码名称</div> | <div align=\"center\">简要介绍</div> | <div align=\"center\">代码路径</div> |\n",
"|:--|:--|:--|\n",
"| `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n",
"| `Accuracy` | 正确率,最为常用 | `/core/metrics/accuracy.py` |\n",
"| `TransformersAccuracy` | 正确率,为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n",
"| `ClassifyFPreRecMetric` | 召回率、精确率、F1值适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n",
"| `SpanFPreRecMetric` | 召回率、精确率、F1值适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |"
]
},
{
"cell_type": "markdown",
"id": "fdc083a3",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"source": [
"&emsp; 如`tutorial-0`中所述,所有的`metric`都包含`get_metric`和`update`函数,其中\n",
"\n",
"&emsp; &emsp; **`update`函数更新单个`batch`的统计量****`get_metric`函数返回最终结果**,并打印显示\n",
"\n",
"\n",
"### 2.1.1 Accuracy 与 TransformersAccuracy\n",
"\n",
"`Accuracy`,正确率,预测正确的数据`right_num`在总数据`total_num`,中的占比(公式就不用列了\n",
"\n",
"&emsp; `get_metric`函数打印格式为 **`{\"acc#xx\": float, 'total#xx': float, 'correct#xx': float}`**\n",
"\n",
"&emsp; 一般在初始化时不需要传参,`fastNLP`会根据`update`函数的传入参数确定对应后台框架`backend`\n",
"\n",
"&emsp; **`update`函数的参数包括`pred`、`target`、`seq_len`****后者用来标记批次中每笔数据的长度**\n",
"\n",
"`TransformersAccuracy`,继承自`Accuracy`,只是为了兼容`Transformers`框架中相关模型\n",
"\n",
"&emsp; 在`update`函数中,将`Transformers`框架输出的`attention_mask`参数转化为`seq_len`参数\n",
"\n",
"\n",
"### 2.1.2 ClassifyFPreRecMetric 与 SpanFPreRecMetric\n",
"\n",
"`ClassifyFPreRecMetric`,分类评价,`SpanFPreRecMetric`,抽取评价,后者在`tutorial-4`中已出现\n",
"\n",
"&emsp; 两者的相同之处在于:**第一****都包括召回率/查全率`Rec`**、**精确率/查准率`Pre`**、**`F1`值**这三个指标\n",
"\n",
"&emsp; &emsp; `get_metric`函数打印格式为 **`{\"f#xx\": float, 'pre#xx': float, 'rec#xx': float}`**\n",
"\n",
"&emsp; &emsp; 三者的计算公式如下,其中`beta`默认为`1`,即`F1`值是召回率`Rec`和精确率`Pre`的调和平均数\n",
"\n",
"$$\\text{召回率}\\ Rec=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有本来是正例的数量}}\\qquad \\text{精确率}\\ Pre=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有预测为正例的数量}}$$\n",
"\n",
"$$F_{beta} = \\frac{(1 + {beta}^{2})*(Pre*Rec)}{({beta}^{2}*Pre + Rec)}$$\n",
"\n",
"&emsp; **第二**,可以通过参数`only_gross`为`False`,要求返回所有类别的`Rec-Pre-F1`,同时`F1`值又根据参数`f_type`又分为\n",
"\n",
"&emsp; &emsp; **`micro F1`****直接统计所有类别的`Rec-Pre-F1`**)、**`macro F1`****统计各类别的`Rec-Pre-F1`再算术平均**\n",
"\n",
"&emsp; **第三**,两者在初始化时还可以**传入基于`fastNLP.Vocabulary`的`tag_vocab`参数记录数据集中的标签序号**\n",
"\n",
"&emsp; &emsp; **与标签名称之间的映射**,通过字符串列表`ignore_labels`参数,指定若干标签不用于`Rec-Pre-F1`的计算\n",
"\n",
"两者的不同之处在于:`ClassifyFPreRecMetric`针对简单的分类问题,每个分类标签之间彼此独立,不构成标签对\n",
"\n",
"&emsp; &emsp; **`SpanFPreRecMetric`针对更复杂的抽取问题****规定标签`B-xx`和`I-xx`或`B-xx`和`E-xx`构成标签对**\n",
"\n",
"&emsp; 在计算`Rec-Pre-F1`时,`ClassifyFPreRecMetric`只需要考虑标签本身是否正确这就足够了,但是\n",
"\n",
"&emsp; &emsp; 对于`SpanFPreRecMetric`,需要保证**标签符合规则且覆盖的区间与正确结果重合才算正确**\n",
"\n",
"&emsp; &emsp; 因此回到`tutorial-4`中`CoNLL-2003`的`NER`任务,如果评测方法选择`ClassifyFPreRecMetric`\n",
"\n",
"&emsp; &emsp; &emsp; 或者`Accuracy`,会发现虽然评测结果显示很高,这是因为选择的评测方法要求太低\n",
"\n",
"&emsp; &emsp; 最后通过`CoNLL-2003`的词性标注`POS`任务简单演示下`ClassifyFPreRecMetric`相关的使用\n",
"\n",
"```python\n",
"from fastNLP import Vocabulary\n",
"from fastNLP import ClassifyFPreRecMetric\n",
"\n",
"tag_vocab = Vocabulary(padding=None, unknown=None) # 记录序号与标签之间的映射\n",
"tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', \n",
" 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', \n",
" 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', \n",
" 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', \n",
" 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP+', 'WRB', ]) # CoNLL-2003 中的 pos_tags\n",
"ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', ]\n",
"\n",
"FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab, \n",
" ignore_labels=ignore_labels, # 表示评测/优化中不考虑上述标签的正误/损失\n",
" only_gross=True, # 默认为 True 表示输出所有类别的综合统计结果\n",
" f_type='micro') # 默认为 'micro' 表示统计所有类别的 Rec-Pre-F1\n",
"metrics = {'F1': FPreRec}\n",
"```"
]
},
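{
"cell_type": "markdown",
"id": "a7d3e104",
"metadata": {},
"source": [
"&emsp; To make the difference between `micro` and `macro` concrete, here is a small pure-Python sketch. `classify_prf` and `f_beta` are hypothetical helpers written for this illustration, not `fastNLP` internals; they only follow the formulas above and the `ignore_labels` / `f_type` semantics described in the text\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"\n",
"def f_beta(pre, rec, beta=1.0):\n",
"    # F_beta = (1 + beta^2) * Pre * Rec / (beta^2 * Pre + Rec); 0 if undefined\n",
"    denom = beta ** 2 * pre + rec\n",
"    return (1 + beta ** 2) * pre * rec / denom if denom else 0.0\n",
"\n",
"\n",
"def classify_prf(pred, target, ignore_labels=(), f_type='micro', beta=1.0):\n",
"    # Hypothetical helper for illustration, not fastNLP's implementation.\n",
"    tp, fp, fn = Counter(), Counter(), Counter()\n",
"    for p, t in zip(pred, target):\n",
"        if p == t:\n",
"            tp[t] += 1\n",
"        else:\n",
"            fp[p] += 1  # predicted p, but it was wrong\n",
"            fn[t] += 1  # missed the true label t\n",
"    labels = sorted((set(pred) | set(target)) - set(ignore_labels))\n",
"    if f_type == 'micro':\n",
"        # pool the counts of all labels, then compute a single Pre / Rec / F\n",
"        n_tp = sum(tp[lab] for lab in labels)\n",
"        n_fp = sum(fp[lab] for lab in labels)\n",
"        n_fn = sum(fn[lab] for lab in labels)\n",
"        pre = n_tp / (n_tp + n_fp) if n_tp + n_fp else 0.0\n",
"        rec = n_tp / (n_tp + n_fn) if n_tp + n_fn else 0.0\n",
"        return {'f': f_beta(pre, rec, beta), 'pre': pre, 'rec': rec}\n",
"    # macro: compute Pre / Rec / F for each label, then take the arithmetic mean\n",
"    rows = []\n",
"    for lab in labels:\n",
"        pre = tp[lab] / (tp[lab] + fp[lab]) if tp[lab] + fp[lab] else 0.0\n",
"        rec = tp[lab] / (tp[lab] + fn[lab]) if tp[lab] + fn[lab] else 0.0\n",
"        rows.append((f_beta(pre, rec, beta), pre, rec))\n",
"    n = len(rows)\n",
"    return {'f': sum(r[0] for r in rows) / n,\n",
"            'pre': sum(r[1] for r in rows) / n,\n",
"            'rec': sum(r[2] for r in rows) / n}\n",
"\n",
"\n",
"target = ['NN', 'NN', 'NN', 'VB', 'VB', '.']\n",
"pred = ['NN', 'NN', 'VB', 'VB', 'NN', '.']\n",
"print(classify_prf(pred, target, ignore_labels=['.'], f_type='micro'))\n",
"print(classify_prf(pred, target, ignore_labels=['.'], f_type='macro'))\n",
"```\n",
"\n",
"&emsp; With the ignored label `.` removed, micro averaging pools 3 true positives, 2 false positives, and 2 false negatives, giving 0.6 for all three scores; macro averaging computes F1 = 2/3 for `NN` and 1/2 for `VB`, then averages them to 7/12 ≈ 0.583"
]
},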
{
"cell_type": "markdown",
"id": "8a22f522",
"metadata": {},
"source": [
"### 2.2 自定义的 metric 类型\n",
"\n",
"如上文所述,`Metric`作为所有`metric`的基类,`Accuracy`等都是其子类,同样地,对于**自定义的`metric`类型**\n",
"\n",
"&emsp; &emsp; 也**需要继承自`Metric`类**,同时**内部自定义好`__init__`、`update`和`get_metric`函数**\n",
"\n",
"&emsp; 在`__init__`函数中,根据需求定义评测时需要用到的变量,此处沿用`Accuracy`中的`total_num`和`right_num`\n",
"\n",
"&emsp; 在`update`函数中,根据需求定义评测变量的更新方式,需要注意的是如`tutorial-0`中所述,**`update`的参数名**\n",
"\n",
"&emsp; &emsp; **需要待评估模型在`evaluate_step`中的输出名称一致**,由此**和数据集中对应字段名称一致**,即**参数匹配**\n",
"\n",
"&emsp; &emsp; 在`fastNLP v0.8`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n",
"\n",
"&emsp; &emsp; 此处仍然沿用,因为接下来会需要使用`fastNLP`函数的与定义模型,其输入参数格式即使如此\n",
"\n",
"&emsp; 在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n",
"\n",
"&emsp; &emsp; 其中,字串`'prefix'`表示该`metric`的名称,会对应显示到`trainer`的`progress bar`中\n",
"\n",
"根据上述要求,这里简单定义了一个名为`MyMetric`的评测模块,用于分类问题的评测,以此展开一个实例展示"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "08a872e9",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"from fastNLP import Metric\n",
"\n",
"class MyMetric(Metric):\n",
"\n",
" def __init__(self):\n",
" Metric.__init__(self)\n",
" self.total_num = 0\n",
" self.right_num = 0\n",
"\n",
" def update(self, pred, target):\n",
" self.total_num += target.size(0)\n",
" self.right_num += target.eq(pred).sum().item()\n",
"\n",
" def get_metric(self, reset=True):\n",
" acc = self.right_num / self.total_num\n",
" if reset:\n",
" self.total_num = 0\n",
" self.right_num = 0\n",
" return {'prefix': acc}"
]
},
{
"cell_type": "markdown",
"id": "0155f447",
"metadata": {},
"source": [
"&emsp; 数据使用方面,此处仍然使用`datasets`模块中的`load_dataset`函数,加载`SST-2`二分类数据集"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "5ad81ac7",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ef923b90b19847f4916cccda5d33fc36",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"sst2data = load_dataset('glue', 'sst2')"
]
},
{
"cell_type": "markdown",
"id": "e9d81760",
"metadata": {},
"source": [
"&emsp; 在数据预处理中,需要注意的是,这里原本应该根据`metric`和`model`的输入参数格式,调整\n",
"\n",
"&emsp; &emsp; 数据集中表示预测目标的字段,调整为`target`,在后文中会揭晓为什么,以及如何补救"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cfb28b1b",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing: 0%| | 0/6000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from fastNLP import DataSet\n",
"\n",
"dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
"\n",
"dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split()}, progress_bar=\"tqdm\")\n",
"dataset.delete_field('sentence')\n",
"dataset.delete_field('idx')\n",
"\n",
"from fastNLP import Vocabulary\n",
"\n",
"vocab = Vocabulary()\n",
"vocab.from_dataset(dataset, field_name='words')\n",
"vocab.index_dataset(dataset, field_name='words')\n",
"\n",
"train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n",
"\n",
"from fastNLP import prepare_torch_dataloader\n",
"\n",
"train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
]
},
{
"cell_type": "markdown",
"id": "af3f8c63",
"metadata": {},
"source": [
"&emsp; 模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2fd210c5",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.models.torch import CNNText\n",
"\n",
"model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
"\n",
"from torch.optim import AdamW\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-4)"
]
},
{
"cell_type": "markdown",
"id": "6e723b87",
"metadata": {},
"source": [
"## 3. fastNLP 中 trainer 的补充介绍\n",
"\n",
"### 3.1 trainer 的内部结构\n",
"\n",
"在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经展示了\n",
"\n",
"&emsp; 很多`trainer`的使用案例,这里通过表格,相对完整地介绍`trainer`模块的属性和初始化参数(标粗为必选参数\n",
"\n",
"| <div align=\"center\">名称</div> | <div align=\"center\">参数</div> | <div align=\"center\">属性</div> | <div align=\"center\">功能</div> | <div align=\"center\">内容</div> |\n",
"|:--|:--:|:--:|:--|:--|\n",
"| **`model`** | √ | √ | 指定`trainer`控制的模型 | 视框架而定,如`torch.nn.Module` |\n",
"| `device` | √ | | 指定`trainer`运行的卡位 | 例如`'cpu'`、`'cuda'`、`0`、`[0, 1]`等 |\n",
"| | | √ | 记录`trainer`运行的卡位 | `Device`类型,在初始化阶段生成 |\n",
"| **`driver`** | √ | | 指定`trainer`驱动的框架 | 包括`'torch'`、`'paddle'`、`'jittor'` |\n",
"| | | √ | 记录`trainer`驱动的框架 | `Driver`类型,在初始化阶段生成 |\n",
"| `n_epochs` | √ | - | 指定`trainer`迭代的轮数 | 默认`20`,记录在`driver.n_epochs`中 |\n",
"| **`optimizers`** | √ | √ | 指定`trainer`优化的方法 | 视框架而定,如`torch.optim.Adam` |\n",
"| `metrics` | √ | √ | 指定`trainer`评测的方法 | 字典类型,如`{'acc': Metric()}` |\n",
"| `evaluator` | | √ | 内置的`trainer`评测模块 | `Evaluator`类型,在初始化阶段生成 |\n",
"| `input_mapping` | √ | √ | 调整`dataloader`的参数不匹配 | 函数类型,输出字典匹配`forward`输入参数 |\n",
"| `output_mapping` | √ | √ | 调整`forward`输出的参数不匹配 | 函数类型,输出字典匹配`xx_step`输入参数 |\n",
"| **`train_dataloader`** | √ | √ | 指定`trainer`训练的数据 | `DataLoader`类型,生成视框架而定 |\n",
"| `evaluate_dataloaders` | √ | √ | 指定`trainer`评测的数据 | `DataLoader`类型,生成视框架而定 |\n",
"| `train_fn` | √ | √ | 指定`trainer`获取某个批次的损失值 | 函数类型,默认为`model.train_step` |\n",
"| `evaluate_fn` | √ | √ | 指定`trainer`获取某个批次的评估量 | 函数类型,默认为`model.evaluate_step` |\n",
"| `batch_step_fn` | √ | √ | 指定`trainer`训练时前向传输一个批次的方式 | 函数类型,默认为`TrainBatchLoop.batch_step_fn` |\n",
"| `evaluate_batch_step_fn` | √ | √ | 指定`trainer`评测时前向传输一个批次的方式 | 函数类型,默认为`EvaluateBatchLoop.batch_step_fn` |\n",
"| `accumulation_steps` | √ | √ | 指定`trainer`训练时反向传播的频率 | 默认为`1`,即每个批次都反向传播 |\n",
"| `evaluate_every` | √ | √ | 指定`evaluator`评测时计算的频率 | 默认`-1`表示每个循环一次,相反`1`表示每个批次一次 |\n",
"| `progress_bar` | √ | √ | 指定`trainer`训练和评测时的进度条样式 | 包括`'auto'`、`'tqdm'`、`'raw'`、`'rich'` |\n",
"| `callbacks` | √ | | 指定`trainer`训练时需要触发的函数 | `Callback`列表类型,详见`tutorial-7` |\n",
"| `callback_manager` | | √ | 记录与管理`callbacks`相关内容 | `CallbackManager`类型,详见`tutorial-7` |\n",
"| `monitor` | √ | √ | 辅助部分的`callbacks`相关内容 | 字符串/函数类型,详见`tutorial-7` |\n",
"| `marker` | √ | √ | 标记`trainer`实例,辅助`callbacks`相关内容 | 字符串型,详见`tutorial-7` |\n",
"| `trainer_state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `TrainerState`类型,详见`tutorial-7` |\n",
"| `state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `State`类型,详见`tutorial-7` |\n",
"| `fp16` | √ | √ | 指定`trainer`是否进行混合精度训练 | 布尔类型,默认`False` |"
]
},
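{
"cell_type": "markdown",
"id": "a7d3e105",
"metadata": {},
"source": [
"&emsp; Two of the frequency parameters in the table, `accumulation_steps` and `evaluate_every`, are easiest to understand as schedules. The helpers below are hypothetical and only model when updates and evaluations would happen under standard gradient accumulation. As assumptions beyond the table, a positive `evaluate_every` of `k` is taken to mean every `k` batches, a negative value `-k` every `k` epochs, and leftover batches that do not complete an accumulation window are simply not followed by an update\n",
"\n",
"```python\n",
"def update_batches(num_batches, accumulation_steps=1):\n",
"    # Hypothetical helper: 1-based indices of the batches after which the\n",
"    # optimizer steps and gradients are zeroed. backward() runs on every\n",
"    # batch, so the gradients of accumulation_steps batches are summed first.\n",
"    return [i for i in range(1, num_batches + 1) if i % accumulation_steps == 0]\n",
"\n",
"\n",
"def evaluate_points(n_epochs, batches_per_epoch, evaluate_every=-1):\n",
"    # Hypothetical helper (assumed semantics): a positive k evaluates every k\n",
"    # batches, a negative -k evaluates at the end of every k-th epoch.\n",
"    points = []\n",
"    global_batch = 0\n",
"    for epoch in range(1, n_epochs + 1):\n",
"        for _ in range(batches_per_epoch):\n",
"            global_batch += 1\n",
"            if evaluate_every > 0 and global_batch % evaluate_every == 0:\n",
"                points.append(('batch', global_batch))\n",
"        if evaluate_every < 0 and epoch % (-evaluate_every) == 0:\n",
"            points.append(('epoch', epoch))\n",
"    return points\n",
"\n",
"\n",
"print(update_batches(10, accumulation_steps=4))  # prints [4, 8]\n",
"print(evaluate_points(2, batches_per_epoch=3))  # prints [('epoch', 1), ('epoch', 2)]\n",
"print(evaluate_points(1, batches_per_epoch=4, evaluate_every=2))  # prints [('batch', 2), ('batch', 4)]\n",
"```"
]
},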
{
"cell_type": "markdown",
"id": "9e13ee08",
"metadata": {},
"source": [
"其中,**`input_mapping`和`output_mapping`** 定义形式如下:输入字典形式的数据,根据参数匹配要求\n",
"\n",
"&emsp; 调整数据格式,这里就回应了前文未在数据集预处理时调整格式的问题,**总之参数匹配一定要求**"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "de96c1d1",
"metadata": {},
"outputs": [],
"source": [
"def input_mapping(data):\n",
" data['target'] = data['label']\n",
" return data"
]
},
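{
"cell_type": "markdown",
"id": "a7d3e106",
"metadata": {},
"source": [
"&emsp; A quick way to check a mapping function is to call it on a hand-made batch. The batch below is a made-up stand-in for what the `dataloader` yields, and `output_mapping` with its `logits_argmax` key is a hypothetical example of the same idea applied to model outputs. Note that `input_mapping` modifies the dictionary in place and returns the same object\n",
"\n",
"```python\n",
"def input_mapping(data):\n",
"    # same as the cell above: expose the 'label' field under the name 'target'\n",
"    data['target'] = data['label']\n",
"    return data\n",
"\n",
"\n",
"def output_mapping(outputs):\n",
"    # hypothetical: rename a model output called 'logits_argmax' to 'pred'\n",
"    return {'pred': outputs['logits_argmax']}\n",
"\n",
"\n",
"batch = {'words': [[4, 17, 9], [5, 2, 0]], 'seq_len': [3, 2], 'label': [1, 0]}  # made-up batch\n",
"mapped = input_mapping(batch)\n",
"print(sorted(mapped))  # prints ['label', 'seq_len', 'target', 'words']\n",
"print(mapped is batch)  # prints True: the dict is modified in place\n",
"print(output_mapping({'logits_argmax': [1, 0]}))  # prints {'pred': [1, 0]}\n",
"```"
]
},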
{
"cell_type": "markdown",
"id": "2fc8b9f3",
"metadata": {},
"source": [
"&emsp; 而`trainer`模块的基础方法列表如下,相关进阶操作,如“`on`系列函数”、`callback`控制,请参考后续的`tutorial-7`\n",
"\n",
"| <div align=\"center\">名称</div> |<div align=\"center\">功能</div> | <div align=\"center\">主要参数</div> |\n",
"|:--|:--|:--|\n",
"| `run` | 控制`trainer`中模型的训练和评测 | 详见后文 |\n",
"| `train_step` | 实现`trainer`训练中一个批数据的前向传播过程 | 输入`batch` |\n",
"| `backward` | 实现`trainer`训练中一次损失的反向传播过程 | 输入`output` |\n",
"| `zero_grad` | 实现`trainer`训练中`optimizers`的梯度置零 | 无输入 |\n",
"| `step` | 实现`trainer`训练中`optimizers`的参数更新 | 无输入 |\n",
"| `epoch_evaluate` | 实现`trainer`训练中每个循环的评测,实际是否执行取决于评测频率 | 无输入 |\n",
"| `step_evaluate` | 实现`trainer`训练中每个批次的评测,实际是否执行取决于评测频率 | 无输入 |\n",
"| `save_model` | 保存`trainer`中的模型参数/状态字典至`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`False` |\n",
"| `load_model` | 加载`trainer`中的模型参数/状态字典自`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只加载状态字典,默认`True` |\n",
"| `save_checkpoint` | <div style=\"line-height:25px;\">保存`trainer`中模型参数/状态字典 以及 `callback`、`sampler`<br>和`optimizer`的状态至`fastnlp_model/checkpoint.pkl.tar`</div> | `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`True` |\n",
"| `load_checkpoint` | <div style=\"line-height:25px;\">加载`trainer`中模型参数/状态字典 以及 `callback`、`sampler`<br>和`optimizer`的状态自`fastnlp_model/checkpoint.pkl.tar`</div> | <div style=\"line-height:25px;\">`folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`True`<br>`resume_training`指明是否只精确到上次训练的批量,默认`True`</div> |\n",
"| `add_callback_fn` | 在`trainer`初始化后添加`callback`函数 | 输入`event`指明回调时机,`fn`指明回调函数 |\n",
"| `on` | 函数修饰器,将一个函数转变为`callback`函数 | 详见`tutorial-7` |\n",
"\n",
"<!-- ```python\n",
"Trainer.__init__():\n",
"\ton_after_trainer_initialized(trainer, driver)\n",
"Trainer.run():\n",
"\tif num_eval_sanity_batch > 0: # 如果设置了 num_eval_sanity_batch\n",
"\t\ton_sanity_check_begin(trainer)\n",
"\t\ton_sanity_check_end(trainer, sanity_check_res)\n",
"\ttry:\n",
"\t\ton_train_begin(trainer)\n",
"\t\twhile cur_epoch_idx < n_epochs:\n",
"\t\t\ton_train_epoch_begin(trainer)\n",
"\t\t\twhile batch_idx_in_epoch<=num_batches_per_epoch:\n",
"\t\t\t\ton_fetch_data_begin(trainer)\n",
"\t\t\t\tbatch = next(dataloader)\n",
"\t\t\t\ton_fetch_data_end(trainer)\n",
"\t\t\t\ton_train_batch_begin(trainer, batch, indices)\n",
"\t\t\t\ton_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping 后的\n",
"\t\t\t\ton_after_backward(trainer)\n",
"\t\t\t\ton_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n",
"\t\t\t\ton_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n",
"\t\t\t\ton_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n",
"\t\t\t\ton_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n",
"\t\t\t\ton_train_batch_end(trainer)\n",
"\t\t\ton_train_epoch_end(trainer)\n",
"\texcept BaseException:\n",
"\t\tself.on_exception(trainer, exception)\n",
"\tfinally:\n",
"\t\ton_train_end(trainer)\n",
"``` -->"
]
},
{
"cell_type": "markdown",
"id": "1e21df35",
"metadata": {},
"source": [
"紧接着,初始化`trainer`实例,继续完成`SST-2`分类,其中`metrics`输入的键值对,字串`'suffix'`和之前定义的\n",
"\n",
"&emsp; 字串`'prefix'`将拼接在一起显示到`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`"
]
},
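{
"cell_type": "markdown",
"id": "a7d3e107",
"metadata": {},
"source": [
"&emsp; The naming rule can be sketched as follows. `display_keys` is a hypothetical helper, not a `fastNLP` function; it joins each key returned by `get_metric` with the name the metric is registered under in `metrics`, using `#` as the separator\n",
"\n",
"```python\n",
"def display_keys(results):\n",
"    # Hypothetical helper. `results` maps each name used in `metrics`\n",
"    # (e.g. 'suffix') to the dict returned by that metric's get_metric().\n",
"    shown = {}\n",
"    for name, metric_result in results.items():\n",
"        for key, value in metric_result.items():\n",
"            shown[key + '#' + name] = value\n",
"    return shown\n",
"\n",
"\n",
"results = {'suffix': {'prefix': 0.6875}}  # e.g. what MyMetric().get_metric() could return\n",
"print(display_keys(results))  # prints {'prefix#suffix': 0.6875}\n",
"```"
]
},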
{
"cell_type": "code",
"execution_count": 6,
"id": "926a9c50",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Trainer\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" driver='torch',\n",
" device=0, # 'cuda'\n",
" n_epochs=10,\n",
" optimizers=optimizers,\n",
" input_mapping=input_mapping,\n",
" train_dataloader=train_dataloader,\n",
" evaluate_dataloaders=evaluate_dataloader,\n",
" metrics={'suffix': MyMetric()}\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b1b2e8b7",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"source": [
"最后就是`run`函数的使用,关于其参数,这里也以表格形式列出,由此就解答了`num_eval_batch_per_dl=10`的含义\n",
"\n",
"| <div align=\"center\">名称</div> | <div align=\"center\">功能</div> | <div align=\"center\">默认值</div> |\n",
"|:--|:--|:--|\n",
"| `num_train_batch_per_epoch` | 指定`trainer`训练时,每个循环计算批量数目 | 整数类型,默认`-1`,表示训练时,每个循环计算所有批量 |\n",
"| `num_eval_batch_per_dl` | 指定`trainer`评测时,每个循环计算批量数目 | 整数类型,默认`-1`,表示评测时,每个循环计算所有批量 |\n",
"| `num_eval_sanity_batch` | 指定`trainer`训练开始前,试探性评测批量数目 | 整数类型,默认`2`,表示训练开始前评估两个批量 |\n",
"| `resume_from` | 指定`trainer`恢复状态的路径,需要是文件夹 | 字符串型,默认`None`,使用可参考`CheckpointCallback` |\n",
"| `resume_training` | 指定`trainer`恢复状态的程度 | 布尔类型,默认`True`恢复所有状态,`False`仅恢复`model`和`optimizers`状态 |"
]
},
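{
"cell_type": "markdown",
"id": "a7d3e108",
"metadata": {},
"source": [
"&emsp; The `-1` convention of `num_train_batch_per_epoch` and `num_eval_batch_per_dl` can be illustrated with `itertools.islice`. `limit_batches` is a hypothetical helper and the list of strings stands in for a real `DataLoader`; we assume a positive value keeps the first `n` batches, as the table implies, without claiming this is how `fastNLP` implements it\n",
"\n",
"```python\n",
"from itertools import islice\n",
"\n",
"\n",
"def limit_batches(dataloader, num_batches=-1):\n",
"    # Hypothetical helper: -1 keeps every batch, a positive n keeps the first n.\n",
"    if num_batches == -1:\n",
"        return iter(dataloader)\n",
"    return islice(dataloader, num_batches)\n",
"\n",
"\n",
"eval_loader = ['batch' + str(i) for i in range(25)]  # stand-in for a DataLoader\n",
"print(len(list(limit_batches(eval_loader))))  # prints 25\n",
"print(len(list(limit_batches(eval_loader, 10))))  # prints 10\n",
"```"
]
},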
{
"cell_type": "code",
"execution_count": 7,
"id": "43be274f",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[09:30:35] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#596\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">596</span></a>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
"</pre>\n"
],
"text/plain": [
"/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
"</pre>\n"
],
"text/plain": [
"/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" self.msg_id = ip.kernel._parent_header['header']['msg_id']\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6875</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8125</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.825</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8125</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run(num_eval_batch_per_dl=10)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}