fastNLP/tutorials/fastnlp_tutorial_0.ipynb
2022-06-01 23:15:21 +08:00

1566 lines
62 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"id": "aec0fde7",
"metadata": {},
"source": [
"# T0. trainer 和 evaluator 的基本使用\n",
"\n",
"  1   trainer 和 evaluator 的基本关系\n",
" \n",
"    1.1   trainer 和 evaluater 的初始化\n",
"\n",
"    1.2   driver 的含义与使用要求\n",
"\n",
"    1.3   trainer 内部初始化 evaluater\n",
"\n",
"  2   使用 fastNLP 搭建 argmax 模型\n",
"\n",
"    2.1   trainer_step 和 evaluator_step\n",
"\n",
"    2.2   trainer 和 evaluator 的参数匹配\n",
"\n",
"    2.3   示例argmax 模型的搭建\n",
"\n",
"  3   使用 fastNLP 训练 argmax 模型\n",
" \n",
"    3.1   trainer 外部初始化的 evaluator\n",
"\n",
"    3.2   trainer 内部初始化的 evaluator "
]
},
{
"cell_type": "markdown",
"id": "09ea669a",
"metadata": {},
"source": [
"## 1. trainer 和 evaluator 的基本关系\n",
"\n",
"### 1.1 trainer 和 evaluator 的初始化\n",
"\n",
"在`fastNLP 0.8`中,**`Trainer`模块和`Evaluator`模块分别表示“训练器”和“评测器”**\n",
"\n",
"  对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n",
"\n",
"在`fastNLP 0.8`中,需要注意,在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n",
"\n",
"  非常关键的问题在于**如何正确设置二者的`driver`**。这就引入了另一个问题:什么是 `driver`\n",
"\n",
"\n",
"```python\n",
"trainer = Trainer(\n",
" model=model, # 模型基于 torch.nn.Module\n",
" train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n",
" optimizers=optimizer, # 优化模块基于 torch.optim.*\n",
"\t...\n",
"\tdriver=\"torch\", # 使用 pytorch 模块进行训练 \n",
"\tdevice='cuda', # 使用 GPU0 显卡执行训练\n",
"\t...\n",
")\n",
"...\n",
"evaluator = Evaluator(\n",
" model=model, # 模型基于 torch.nn.Module\n",
" dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n",
" metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n",
" ...\n",
" driver=trainer.driver, # 保持同 trainer 的 driver 一致\n",
"\tdevice=None,\n",
" ...\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "3c11fe1a",
"metadata": {},
"source": [
"### 1.2 driver 的含义与使用要求\n",
"\n",
"在`fastNLP 0.8`中,**`driver`**这一概念被用来表示**控制具体训练的各个步骤的最终执行部分**\n",
"\n",
"  例如神经网络前向、后向传播的具体执行、网络参数的优化和数据在设备间的迁移等\n",
"\n",
"在`fastNLP 0.8`中,**`Trainer`和`Evaluator`都依赖于具体的`driver`来完成整体的工作流程**\n",
"\n",
"  具体`driver`与`Trainer`以及`Evaluator`之间的关系请参考`fastNLP 0.8`的框架设计\n",
"\n",
"注:这里给出一条建议:**在同一脚本中****所有的`Trainer`和`Evaluator`使用的`driver`应当保持一致**\n",
"\n",
"  尽量不出现,之前使用单卡的`driver`,后面又使用多卡的`driver`,这是因为,当脚本执行至\n",
"\n",
"  多卡`driver`处时,会重启一个进程执行之前所有内容,如此一来可能会造成一些意想不到的麻烦"
]
},
{
"cell_type": "markdown",
"id": "2cac4a1a",
"metadata": {},
"source": [
"### 1.3 Trainer 内部初始化 Evaluator\n",
"\n",
"在`fastNLP 0.8`中,如果在**初始化`Trainer`时****传入参数`evaluator_dataloaders`和`metrics`**\n",
"\n",
"  则在`Trainer`内部,也会初始化单独的`Evaluator`来帮助训练过程中对验证集的评测\n",
"\n",
"```python\n",
"trainer = Trainer(\n",
" model=model,\n",
" train_dataloader=train_dataloader,\n",
" optimizers=optimizer,\n",
"\t...\n",
"\tdriver=\"torch\",\n",
"\tdevice='cuda',\n",
"\t...\n",
" evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluator_dataloaders\n",
" metrics={'acc': Accuracy()}, # 传入参数 metrics\n",
"\t...\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "0c9c7dda",
"metadata": {},
"source": [
"## 2. argmax 模型的搭建实例"
]
},
{
"cell_type": "markdown",
"id": "524ac200",
"metadata": {},
"source": [
"### 2.1 trainer_step 和 evaluator_step\n",
"\n",
"在`fastNLP 0.8`中,使用`pytorch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n",
"\n",
"  添加`pytorch`要求的`forward`方法外,还需要添加 **`train_step`** 和 **`evaluate_step`** 这两个方法\n",
"\n",
"```python\n",
"class Model(torch.nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
" self.loss_fn = torch.nn.CrossEntropyLoss()\n",
" pass\n",
"\n",
" def forward(self, x):\n",
" pass\n",
"\n",
" def train_step(self, x, y):\n",
" pred = self(x)\n",
" return {\"loss\": self.loss_fn(pred, y)}\n",
"\n",
" def evaluate_step(self, x, y):\n",
" pred = self(x)\n",
" pred = torch.max(pred, dim=-1)[1]\n",
" return {\"pred\": pred, \"target\": y}\n",
"```\n",
"***\n",
"在`fastNLP 0.8`中,**函数`train_step`是`Trainer`中参数`train_fn`的默认值**\n",
"\n",
"  由于,在`Trainer`训练时,**`Trainer`通过参数`train_fn`对应的模型方法获得当前数据批次的损失值**\n",
"\n",
"  因此,在`Trainer`训练时,`Trainer`首先会寻找模型是否定义了`train_step`这一方法\n",
"\n",
"    如果没有找到,那么`Trainer`会默认使用模型的`forward`函数来进行训练的前向传播过程\n",
"\n",
"注:在`fastNLP 0.8`中,**`Trainer`要求模型通过`train_step`来返回一个字典****满足如`{\"loss\": loss}`的形式**\n",
"\n",
"  此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现输出的转换详见trainer的详细讲解待补充\n",
"\n",
"同样,在`fastNLP 0.8`中,**函数`evaluate_step`是`Evaluator`中参数`evaluate_fn`的默认值**\n",
"\n",
"  在`Evaluator`测试时,**`Evaluator`通过参数`evaluate_fn`对应的模型方法获得当前数据批次的评测结果**\n",
"\n",
"  从用户角度,模型通过`evaluate_step`方法来返回一个字典,内容与传入`Evaluator`的`metrics`一致\n",
"\n",
"  从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n",
"\n",
"<img src=\"./figures/T0-fig-training-structure.png\" width=\"68%\" height=\"68%\" align=\"center\"></img>"
]
},
{
"cell_type": "markdown",
"id": "fb3272eb",
"metadata": {},
"source": [
"### 2.2 trainer 和 evaluator 的参数匹配\n",
"\n",
"在`fastNLP 0.8`中,参数匹配涉及到两个方面,分别是在\n",
"\n",
"&emsp; 一方面,**在模型的前向传播中****`dataloader`向`train_step`或`evaluate_step`函数传递`batch`**\n",
"\n",
"&emsp; 另方面,**在模型的评测过程中****`evaluate_dataloader`向`metric`的`update`函数传递`batch`**\n",
"\n",
"对于前者,在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`False`时\n",
"\n",
"&emsp; &emsp; **`fastNLP 0.8`要求`dataloader`生成的每个`batch`****满足如`{\"x\": x, \"y\": y}`的形式**\n",
"\n",
"&emsp; 同时,`fastNLP 0.8`会查看模型的`train_step`和`evaluate_step`方法的参数签名,并为对应参数传入对应数值\n",
"\n",
"&emsp; &emsp; **字典形式的定义****对应在`Dataset`定义的`__getitem__`方法中**,例如下方的`ArgMaxDatset`\n",
"\n",
"&emsp; 而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n",
"\n",
"&emsp; &emsp; `fastNLP 0.8`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n",
"\n",
"```python\n",
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(self, x, y):\n",
" self.x = x\n",
" self.y = y\n",
"\n",
" def __len__(self):\n",
" return len(self.x)\n",
"\n",
" def __getitem__(self, item):\n",
" return {\"x\": self.x[item], \"y\": self.y[item]}\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "f5f1a6aa",
"metadata": {},
"source": [
"对于后者,首先要明确,在`Trainer`和`Evaluator`中,`metrics`的计算分为`update`和`get_metric`两步\n",
"\n",
"&emsp; &emsp; **`update`函数****针对一个`batch`的预测结果**,计算其累计的评价指标\n",
"\n",
"&emsp; &emsp; **`get_metric`函数****统计`update`函数累计的评价指标**,来计算最终的评价结果\n",
"\n",
"&emsp; 例如对于`Accuracy`来说,`update`函数会更新一个`batch`的正例数量`right_num`和负例数量`total_num`\n",
"\n",
"&emsp; &emsp; 而`get_metric`函数则会返回所有`batch`的评测值`right_num / total_num`\n",
"\n",
"&emsp; 在此基础上,**`fastNLP 0.8`要求`evaluate_dataloader`生成的每个`batch`传递给对应的`metric`**\n",
"\n",
"&emsp; &emsp; **以`{\"pred\": y_pred, \"target\": y_true}`的形式**,对应其`update`函数的函数签名\n",
"\n",
"<img src=\"./figures/T0-fig-parameter-matching.png\" width=\"75%\" height=\"75%\" align=\"center\"></img>"
]
},
{
"cell_type": "markdown",
"id": "f62b7bb1",
"metadata": {},
"source": [
"### 2.3 示例argmax 模型的搭建\n",
"\n",
"下文将通过训练`argmax`模型,简单介绍如何`Trainer`模块的使用方式\n",
"\n",
"&emsp; 首先,使用`pytorch.nn.Module`定义`argmax`模型,目标是输入一组固定维度的向量,输出其中数值最大的数的索引"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5314482b",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class ArgMaxModel(nn.Module):\n",
" def __init__(self, num_labels, feature_dimension):\n",
" nn.Module.__init__(self)\n",
" self.num_labels = num_labels\n",
"\n",
" self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n",
" self.ac1 = nn.ReLU()\n",
" self.linear2 = nn.Linear(in_features=10, out_features=10)\n",
" self.ac2 = nn.ReLU()\n",
" self.output = nn.Linear(in_features=10, out_features=num_labels)\n",
" self.loss_fn = nn.CrossEntropyLoss()\n",
"\n",
" def forward(self, x):\n",
" pred = self.ac1(self.linear1(x))\n",
" pred = self.ac2(self.linear2(pred))\n",
" pred = self.output(pred)\n",
" return pred\n",
"\n",
" def train_step(self, x, y):\n",
" pred = self(x)\n",
" return {\"loss\": self.loss_fn(pred, y)}\n",
"\n",
" def evaluate_step(self, x, y):\n",
" pred = self(x)\n",
" pred = torch.max(pred, dim=-1)[1]\n",
" return {\"pred\": pred, \"target\": y}"
]
},
{
"cell_type": "markdown",
"id": "71f3fa6b",
"metadata": {},
"source": [
"&emsp; 接着,使用`torch.utils.data.Dataset`定义`ArgMaxDataset`数据集\n",
"\n",
"&emsp; &emsp; 数据集包含三个参数:维度`feature_dimension`、数据量`data_num`和随机种子`seed`\n",
"\n",
"&emsp; &emsp; 数据及初始化是,自动生成指定维度的向量,并为每个向量标注出其中最大值的索引作为预测标签"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fe612e61",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"class ArgMaxDataset(Dataset):\n",
" def __init__(self, feature_dimension, data_num=1000, seed=0):\n",
" self.num_labels = feature_dimension\n",
" self.feature_dimension = feature_dimension\n",
" self.data_num = data_num\n",
" self.seed = seed\n",
"\n",
" g = torch.Generator()\n",
" g.manual_seed(1000)\n",
" self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n",
" self.y = torch.max(self.x, dim=-1)[1]\n",
"\n",
" def __len__(self):\n",
" return self.data_num\n",
"\n",
" def __getitem__(self, item):\n",
" return {\"x\": self.x[item], \"y\": self.y[item]}"
]
},
{
"cell_type": "markdown",
"id": "2cb96332",
"metadata": {},
"source": [
"&emsp; 然后,根据`ArgMaxModel`类初始化模型实例,保持输入维度`feature_dimension`和输出标签数量`num_labels`一致\n",
"\n",
"&emsp; &emsp; 再根据`ArgMaxDataset`类初始化两个数据集实例分别用来模型测试和模型评测数据量各1000笔"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "76172ef8",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"model = ArgMaxModel(num_labels=10, feature_dimension=10)\n",
"\n",
"train_dataset = ArgMaxDataset(feature_dimension=10, data_num=1000)\n",
"evaluate_dataset = ArgMaxDataset(feature_dimension=10, data_num=100)"
]
},
{
"cell_type": "markdown",
"id": "4e7d25ee",
"metadata": {},
"source": [
"&emsp; 此外,使用`torch.utils.data.DataLoader`初始化两个数据加载模块批量大小同为8分别用于训练和测评"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "363b5b09",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
"evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)"
]
},
{
"cell_type": "markdown",
"id": "c8d4443f",
"metadata": {},
"source": [
"&emsp; 最后,使用`torch.optim.SGD`初始化一个优化模块,基于随机梯度下降法"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dc28a2d9",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"from torch.optim import SGD\n",
"\n",
"optimizer = SGD(model.parameters(), lr=0.001)"
]
},
{
"cell_type": "markdown",
"id": "eb8ca6cf",
"metadata": {},
"source": [
"## 3. 使用 fastNLP 0.8 训练 argmax 模型\n",
"\n",
"### 3.1 trainer 外部初始化的 evaluator"
]
},
{
"cell_type": "markdown",
"id": "55145553",
"metadata": {},
"source": [
"通过从`fastNLP`库中导入`Trainer`类,初始化`trainer`实例,对模型进行训练\n",
"\n",
"&emsp; 需要导入预先定义好的模型`model`、对应的数据加载模块`train_dataloader`、优化模块`optimizer`\n",
"\n",
"&emsp; 通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n",
"\n",
"&emsp; &emsp; 但对于`\"auto\"`和`\"rich\"`格式,在`jupyter`中,进度条会在训练结束后会被丢弃\n",
"\n",
"&emsp; 通过`n_epochs`设定优化迭代轮数默认为20全部`Trainer`的全部变量与函数可以通过`dir(trainer)`查询"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b51b7a2d",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"from fastNLP import Trainer\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" driver=\"torch\",\n",
" device='cuda',\n",
" train_dataloader=train_dataloader,\n",
" optimizers=optimizer,\n",
" n_epochs=10, # 设定迭代轮数 \n",
" progress_bar=\"auto\" # 设定进度条格式\n",
")"
]
},
{
"cell_type": "markdown",
"id": "6e202d6e",
"metadata": {},
"source": [
"通过使用`Trainer`类的`run`函数,进行训练\n",
"\n",
"&emsp; 其中,可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止,默认全部\n",
"\n",
"&emsp; `run`函数完成后在`jupyter`中没有输出保留,此外,通过`help(trainer.run)`可以查询`run`函数的详细内容"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ba047ead",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run()"
]
},
{
"cell_type": "markdown",
"id": "c16c5fa4",
"metadata": {},
"source": [
"通过从`fastNLP`库中导入`Evaluator`类,初始化`evaluator`实例,对模型进行评测\n",
"\n",
"&emsp; 需要导入预先定义好的模型`model`、对应的数据加载模块`evaluate_dataloader`\n",
"\n",
"&emsp; 需要注意的是评测方法`metrics`,设定为形如`{'acc': fastNLP.core.metrics.Accuracy()}`的字典\n",
"\n",
"&emsp; 类似地,也可以通过`progress_bar`限定进度条格式,默认为`\"auto\"`"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1c6b6b36",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"from fastNLP import Evaluator\n",
"from fastNLP.core.metrics import Accuracy\n",
"\n",
"evaluator = Evaluator(\n",
" model=model,\n",
" driver=trainer.driver, # 需要使用 trainer 已经启动的 driver\n",
" device=None,\n",
" dataloaders=evaluate_dataloader,\n",
" metrics={'acc': Accuracy()} # 需要严格使用此种形式的字典\n",
")"
]
},
{
"cell_type": "markdown",
"id": "8157bb9b",
"metadata": {},
"source": [
"通过使用`Evaluator`类的`run`函数,进行训练\n",
"\n",
"&emsp; 其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止,默认全部\n",
"\n",
"&emsp; 最终,输出形如`{'acc#acc': acc}`的字典,在`jupyter`中,进度条会在评测结束后会被丢弃"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f7cb0165",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.31</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'total#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'correct#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31.0</span><span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.31\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m31.0\u001b[0m\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"evaluator.run()"
]
},
{
"cell_type": "markdown",
"id": "dd9f68fa",
"metadata": {},
"source": [
"### 3.2 trainer 内部初始化的 evaluator \n",
"\n",
"通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n",
"\n",
"&emsp; 通过`progress_bar`同时设定训练和评估进度条格式,在`jupyter`中,在进度条训练结束后会被丢弃\n",
"\n",
"&emsp; 但是中间的评估结果仍会保留;**通过`evaluate_every`设定评估频率**,可以为负数、正数或者函数:\n",
"\n",
"&emsp; &emsp; **为负数时****表示每隔几个`epoch`评估一次****为正数时****则表示每隔几个`batch`评估一次**"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "183c7d19",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"trainer = Trainer(\n",
" model=model,\n",
" driver=trainer.driver, # 因为是在同个脚本中,这里的 driver 同样需要重用\n",
" train_dataloader=train_dataloader,\n",
" evaluate_dataloaders=evaluate_dataloader,\n",
" metrics={'acc': Accuracy()},\n",
" optimizers=optimizer,\n",
" n_epochs=10, \n",
" evaluate_every=-1, # 表示每个 epoch 的结束进行评估\n",
")"
]
},
{
"cell_type": "markdown",
"id": "714cc404",
"metadata": {},
"source": [
"通过使用`Trainer`类的`run`函数,进行训练\n",
"\n",
"&emsp; 还可以通过**参数`num_eval_sanity_batch`决定每次训练前运行多少个`evaluate_batch`进行评测****默认为`2`**\n",
"\n",
"&emsp; 之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此**试探性评测**"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2e4daa2c",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[18:28:25] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#592\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">592</span></a>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[18:28:25]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=549287;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=645362;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.31</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.31\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m31.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.33</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">33.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.33\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m33.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.34</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">34.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.34\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m34.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.37</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">37.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.37\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m37.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.4</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">40.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.4\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m40.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c4e9c619",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'acc#acc': 0.4, 'total#acc': 100.0, 'correct#acc': 40.0}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluator.run()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "db784d5b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['__annotations__',\n",
" '__class__',\n",
" '__delattr__',\n",
" '__dict__',\n",
" '__dir__',\n",
" '__doc__',\n",
" '__eq__',\n",
" '__format__',\n",
" '__ge__',\n",
" '__getattribute__',\n",
" '__gt__',\n",
" '__hash__',\n",
" '__init__',\n",
" '__init_subclass__',\n",
" '__le__',\n",
" '__lt__',\n",
" '__module__',\n",
" '__ne__',\n",
" '__new__',\n",
" '__reduce__',\n",
" '__reduce_ex__',\n",
" '__repr__',\n",
" '__setattr__',\n",
" '__sizeof__',\n",
" '__str__',\n",
" '__subclasshook__',\n",
" '__weakref__',\n",
" '_check_callback_called_legality',\n",
" '_check_train_batch_loop_legality',\n",
" '_custom_callbacks',\n",
" '_driver',\n",
" '_evaluate_dataloaders',\n",
" '_fetch_matched_fn_callbacks',\n",
" '_set_num_eval_batch_per_dl',\n",
" '_train_batch_loop',\n",
" '_train_dataloader',\n",
" '_train_step',\n",
" '_train_step_signature_fn',\n",
" 'accumulation_steps',\n",
" 'add_callback_fn',\n",
" 'backward',\n",
" 'batch_idx_in_epoch',\n",
" 'batch_step_fn',\n",
" 'callback_manager',\n",
" 'check_batch_step_fn',\n",
" 'cur_epoch_idx',\n",
" 'data_device',\n",
" 'dataloader',\n",
" 'device',\n",
" 'driver',\n",
" 'driver_name',\n",
" 'epoch_evaluate',\n",
" 'evaluate_batch_step_fn',\n",
" 'evaluate_dataloaders',\n",
" 'evaluate_every',\n",
" 'evaluate_fn',\n",
" 'evaluator',\n",
" 'extract_loss_from_outputs',\n",
" 'fp16',\n",
" 'get_no_sync_context',\n",
" 'global_forward_batches',\n",
" 'has_checked_train_batch_loop',\n",
" 'input_mapping',\n",
" 'kwargs',\n",
" 'larger_better',\n",
" 'load_checkpoint',\n",
" 'load_model',\n",
" 'marker',\n",
" 'metrics',\n",
" 'model',\n",
" 'model_device',\n",
" 'monitor',\n",
" 'move_data_to_device',\n",
" 'n_epochs',\n",
" 'num_batches_per_epoch',\n",
" 'on',\n",
" 'on_after_backward',\n",
" 'on_after_optimizers_step',\n",
" 'on_after_trainer_initialized',\n",
" 'on_after_zero_grad',\n",
" 'on_before_backward',\n",
" 'on_before_optimizers_step',\n",
" 'on_before_zero_grad',\n",
" 'on_evaluate_begin',\n",
" 'on_evaluate_end',\n",
" 'on_exception',\n",
" 'on_fetch_data_begin',\n",
" 'on_fetch_data_end',\n",
" 'on_load_checkpoint',\n",
" 'on_load_model',\n",
" 'on_sanity_check_begin',\n",
" 'on_sanity_check_end',\n",
" 'on_save_checkpoint',\n",
" 'on_save_model',\n",
" 'on_train_batch_begin',\n",
" 'on_train_batch_end',\n",
" 'on_train_begin',\n",
" 'on_train_end',\n",
" 'on_train_epoch_begin',\n",
" 'on_train_epoch_end',\n",
" 'optimizers',\n",
" 'output_mapping',\n",
" 'progress_bar',\n",
" 'run',\n",
" 'run_evaluate',\n",
" 'save_checkpoint',\n",
" 'save_model',\n",
" 'start_batch_idx_in_epoch',\n",
" 'state',\n",
" 'step',\n",
" 'step_evaluate',\n",
" 'total_batches',\n",
" 'train_batch_loop',\n",
" 'train_dataloader',\n",
" 'train_fn',\n",
" 'train_step',\n",
" 'trainer_state',\n",
" 'zero_grad']"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir(trainer)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "953533c4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Help on method run in module fastNLP.core.controllers.trainer:\n",
"\n",
"run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None) method of fastNLP.core.controllers.trainer.Trainer instance\n",
" 该函数是在 ``Trainer`` 初始化后用于真正开始训练的函数;\n",
" \n",
" 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None并且使用 ``CheckpointCallback``\n",
" 去保存断点重训的文件;\n",
" \n",
" :param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 train_dataloader 本身的长度;\n",
" :param num_eval_batch_per_dl: 每个 evaluate_dataloader 验证多少个 batch 停止,*-1* 表示使用 evaluate_dataloader 本身的长度;\n",
" :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 的过程是否有错误。为 0 表示不检测;\n",
" :param resume_from: 从哪个路径下恢复 trainer 的状态,注意该值需要为一个文件夹,例如使用 ``CheckpointCallback`` 时帮助您创建的保存的子文件夹;\n",
" :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False则只恢复 model 和 optimizers 的状态;该参数如果为 ``True``\n",
" 在下一次断点重训的时候我们会精确到上次训练截止的具体的 sample 进行训练;否则我们只会恢复 model 和 optimizers 的状态,而 ``Trainer`` 中的\n",
" 其余状态都是保持初始化时的状态不会改变;\n",
" :param catch_KeyboardInterrupt: 是否捕获 KeyboardInterrupt如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序,\n",
" ``Trainer`` 不会抛出异常,但是会提前退出,然后 ``trainer.run()`` 之后的代码会继续运行。注意该参数在您使用分布式训练的 ``Driver``\n",
" 时无效,例如 ``TorchDDPDriver``;非分布式训练的 ``Driver`` 下该参数默认为 True\n",
" \n",
" .. warning::\n",
" \n",
" 注意初始化的 ``Trainer`` 只能调用一次 ``run`` 函数,即之后的调用 ``run`` 函数实际不会运行,因为此时\n",
" ``trainer.cur_epoch_idx == trainer.n_epochs``\n",
" \n",
" 这意味着如果您需要再次调用 ``run`` 函数,您需要重新再初始化一个 ``Trainer``\n",
" \n",
" .. note::\n",
" \n",
" 您可以使用 ``num_train_batch_per_epoch`` 来简单地对您的训练过程进行验证,例如,当您指定 ``num_train_batch_per_epoch=10`` 后,\n",
" 每一个 epoch 下实际训练的 batch 的数量则会被修改为 10。您可以先使用该值来设定一个较小的训练长度在验证整体的训练流程没有错误后再将\n",
" 该值设定为 **-1** 开始真正的训练;\n",
" \n",
" ``num_eval_batch_per_dl`` 的意思和 ``num_train_batch_per_epoch`` 类似,即您可以通过设定 ``num_eval_batch_per_dl`` 来验证\n",
" 整体的验证流程是否正确;\n",
" \n",
" ``num_eval_sanity_batch`` 的作用可能会让人产生迷惑,其本质和 ``num_eval_batch_per_dl`` 作用一致,但是其只被 ``Trainer`` 使用;\n",
" 并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator``(如果其不为 ``None`` 进行验证,此时验证的 batch 的\n",
" 数量只有 ``num_eval_sanity_batch`` 个;但是对于 ``num_eval_batch_per_dl`` 而言,其表示在实际的整体的训练过程中,每次 ``Evaluator``\n",
" 进行验证时会验证的 batch 的数量。\n",
" \n",
" 并且,在实际真正的训练中,``num_train_batch_per_epoch`` 和 ``num_eval_batch_per_dl`` 应当都被设置为 **-1**,但是 ``num_eval_sanity_batch``\n",
" 应当为一个很小的正整数,例如 2\n",
" \n",
" .. note::\n",
" \n",
" 参数 ``resume_from`` 和 ``resume_training`` 的设立是为了支持断点重训功能;仅当 ``resume_from`` 不为 ``None`` 时,``resume_training`` 才有效;\n",
" \n",
" 断点重训的意思为将上一次训练过程中的 ``Trainer`` 的状态保存下来,包括模型和优化器的状态、当前训练过的 epoch 的数量、对于当前的 epoch\n",
" 已经训练过的 batch 的数量、callbacks 的状态等等;然后在下一次训练时直接加载这些状态,从而直接恢复到上一次训练过程的某一个具体时间点的状态开始训练;\n",
" \n",
" fastNLP 将断点重训分为了 **保存状态** 和 **恢复断点重训** 两部分:\n",
" \n",
" 1. 您需要使用 ``CheckpointCallback`` 来保存训练过程中的 ``Trainer`` 的状态;具体详见 :class:`~fastNLP.core.callbacks.CheckpointCallback`\n",
" ``CheckpointCallback`` 会帮助您把 ``Trainer`` 的状态保存到一个具体的文件夹下,这个文件夹的名字由 ``CheckpointCallback`` 自己生成;\n",
" 2. 在第二次训练开始时,您需要找到您想要加载的 ``Trainer`` 状态所存放的文件夹,然后传入给参数 ``resume_from``\n",
" \n",
" 需要注意的是 **保存状态** 和 **恢复断点重训** 是互不影响的。\n",
"\n"
]
}
],
"source": [
"help(trainer.run)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bc7cb4a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}