fastNLP/tutorials/fastnlp_tutorial_0.ipynb
{
"cells": [
{
"cell_type": "markdown",
"id": "aec0fde7",
"metadata": {},
"source": [
"# T0. trainer 和 evaluator 的基本使用\n",
"\n",
"  1   trainer 和 evaluator 的基本关系\n",
" \n",
"    1.1   trainer 和 evaluater 的初始化\n",
"\n",
"    1.2   driver 的含义与使用要求\n",
"\n",
"    1.3   trainer 内部初始化 evaluater\n",
"\n",
"  2   使用 fastNLP 0.8 搭建 argmax 模型\n",
"\n",
"    2.1   trainer_step 和 evaluator_step\n",
"\n",
"    2.2   trainer 和 evaluator 的参数匹配\n",
"\n",
"    2.3   一个实际案例argmax 模型\n",
"\n",
"  3   使用 fastNLP 0.8 训练 argmax 模型\n",
" \n",
"    3.1   trainer 外部初始化的 evaluator\n",
"\n",
"    3.2   trainer 内部初始化的 evaluator "
]
},
{
"cell_type": "markdown",
"id": "09ea669a",
"metadata": {},
"source": [
"## 1. trainer 和 evaluator 的基本关系\n",
"\n",
"### 1.1 trainer 和 evaluator 的初始化\n",
"\n",
"在`fastNLP 0.8`中,**`Trainer`模块和`Evaluator`模块分别表示“训练器”和“评测器”**\n",
"\n",
"  对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n",
"\n",
"在`fastNLP 0.8`中,需要注意,在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n",
"\n",
"  非常关键的问题在于**如何正确设置二者的`driver`**。这就引入了另一个问题:什么是 `driver`\n",
"\n",
"\n",
"```python\n",
"trainer = Trainer(\n",
" model=model, # 模型基于 torch.nn.Module\n",
" train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n",
" optimizers=optimizer, # 优化模块基于 torch.optim.*\n",
"\t...\n",
"\tdriver=\"torch\", # 使用 pytorch 模块进行训练 \n",
"\tdevice='cuda', # 使用 GPU0 显卡执行训练\n",
"\t...\n",
")\n",
"...\n",
"evaluator = Evaluator(\n",
" model=model, # 模型基于 torch.nn.Module\n",
" dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n",
" metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n",
" ...\n",
" driver=trainer.driver, # 保持同 trainer 的 driver 一致\n",
"\tdevice=None,\n",
" ...\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "3c11fe1a",
"metadata": {},
"source": [
"### 1.2 driver 的含义与使用要求\n",
"\n",
"在`fastNLP 0.8`中,**`driver`**这一概念被用来表示**控制具体训练的各个步骤的最终执行部分**\n",
"\n",
"  例如神经网络前向、后向传播的具体执行、网络参数的优化和数据在设备间的迁移等\n",
"\n",
"在`fastNLP 0.8`中,**`Trainer`和`Evaluator`都依赖于具体的`driver`来完成整体的工作流程**\n",
"\n",
"  具体`driver`与`Trainer`以及`Evaluator`之间的关系请参考`fastNLP 0.8`的框架设计\n",
"\n",
"注:这里给出一条建议:**在同一脚本中****所有的`Trainer`和`Evaluator`使用的`driver`应当保持一致**\n",
"\n",
"  尽量不出现,之前使用单卡的`driver`,后面又使用多卡的`driver`,这是因为,当脚本执行至\n",
"\n",
"  多卡`driver`处时,会重启一个进程执行之前所有内容,如此一来可能会造成一些意想不到的麻烦"
]
},
{
"cell_type": "markdown",
"id": "2cac4a1a",
"metadata": {},
"source": [
"### 1.3 Trainer 内部初始化 Evaluator\n",
"\n",
"在`fastNLP 0.8`中,如果在**初始化`Trainer`时****传入参数`evaluator_dataloaders`和`metrics`**\n",
"\n",
"  则在`Trainer`内部,也会初始化单独的`Evaluator`来帮助训练过程中对验证集的评测\n",
"\n",
"```python\n",
"trainer = Trainer(\n",
" model=model,\n",
" train_dataloader=train_dataloader,\n",
" optimizers=optimizer,\n",
"\t...\n",
"\tdriver=\"torch\",\n",
"\tdevice='cuda',\n",
"\t...\n",
" evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluator_dataloaders\n",
" metrics={'acc': Accuracy()}, # 传入参数 metrics\n",
"\t...\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "0c9c7dda",
"metadata": {},
"source": [
"## 2. argmax 模型的搭建实例"
]
},
{
"cell_type": "markdown",
"id": "524ac200",
"metadata": {},
"source": [
"### 2.1 trainer_step 和 evaluator_step\n",
"\n",
"在`fastNLP 0.8`中,使用`pytorch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n",
"\n",
"  添加`pytorch`要求的`forward`方法外,还需要添加 **`train_step`** 和 **`evaluate_step`** 这两个方法\n",
"\n",
"```python\n",
"class Model(torch.nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
" self.loss_fn = torch.nn.CrossEntropyLoss()\n",
" pass\n",
"\n",
" def forward(self, x):\n",
" pass\n",
"\n",
" def train_step(self, x, y):\n",
" pred = self(x)\n",
" return {\"loss\": self.loss_fn(pred, y)}\n",
"\n",
" def evaluate_step(self, x, y):\n",
" pred = self(x)\n",
" pred = torch.max(pred, dim=-1)[1]\n",
" return {\"pred\": pred, \"target\": y}\n",
"```\n",
"***\n",
"在`fastNLP 0.8`中,**函数`train_step`是`Trainer`中参数`train_fn`的默认值**\n",
"\n",
"  由于,在`Trainer`训练时,**`Trainer`通过参数`train_fn`对应的模型方法获得当前数据批次的损失值**\n",
"\n",
"  因此,在`Trainer`训练时,`Trainer`首先会寻找模型是否定义了`train_step`这一方法\n",
"\n",
"    如果没有找到,那么`Trainer`会默认使用模型的`forward`函数来进行训练的前向传播过程\n",
"\n",
"注:在`fastNLP 0.8`中,**`Trainer`要求模型通过`train_step`来返回一个字典****满足如`{\"loss\": loss}`的形式**\n",
"\n",
"  此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现输出的转换详见trainer的详细讲解待补充\n",
"\n",
"同样,在`fastNLP 0.8`中,**函数`evaluate_step`是`Evaluator`中参数`evaluate_fn`的默认值**\n",
"\n",
"  在`Evaluator`测试时,**`Evaluator`通过参数`evaluate_fn`对应的模型方法获得当前数据批次的评测结果**\n",
"\n",
"  从用户角度,模型通过`evaluate_step`方法来返回一个字典,内容与传入`Evaluator`的`metrics`一致\n",
"\n",
"  从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n",
"\n",
"<img src=\"./figures/T0-fig-training-structure.png\" width=\"68%\" height=\"68%\" align=\"center\"></img>"
]
},
{
"cell_type": "markdown",
"id": "fb3272eb",
"metadata": {},
"source": [
"### 2.2 trainer 和 evaluator 的参数匹配\n",
"\n",
"在`fastNLP 0.8`中,参数匹配涉及到两个方面,分别是在\n",
"\n",
"&emsp; 一方面,**在模型的前向传播中****`dataloader`向`train_step`或`evaluate_step`函数传递`batch`**\n",
"\n",
"&emsp; 另方面,**在模型的评测过程中****`evaluate_dataloader`向`metric`的`update`函数传递`batch`**\n",
"\n",
"对于前者,在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`False`时\n",
"\n",
"&emsp; &emsp; **`fastNLP 0.8`要求`dataloader`生成的每个`batch`****满足如`{\"x\": x, \"y\": y}`的形式**\n",
"\n",
"&emsp; 同时,`fastNLP 0.8`会查看模型的`train_step`和`evaluate_step`方法的参数签名,并为对应参数传入对应数值\n",
"\n",
"&emsp; &emsp; **字典形式的定义****对应在`Dataset`定义的`__getitem__`方法中**,例如下方的`ArgMaxDatset`\n",
"\n",
"&emsp; 而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n",
"\n",
"&emsp; &emsp; `fastNLP 0.8`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n",
"\n",
"```python\n",
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(self, x, y):\n",
" self.x = x\n",
" self.y = y\n",
"\n",
" def __len__(self):\n",
" return len(self.x)\n",
"\n",
" def __getitem__(self, item):\n",
" return {\"x\": self.x[item], \"y\": self.y[item]}\n",
"```"
]
},
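{
"cell_type": "markdown",
"id": "a3f19c2b",
"metadata": {},
"source": [
"To make the former kind of matching concrete, here is a minimal, standalone sketch of how a `batch` dict can be dispatched to a method by inspecting its signature; the helper `call_with_batch` and the toy `train_step` below are purely illustrative and are not part of the `fastNLP` API\n",
"\n",
"```python\n",
"import inspect\n",
"import torch\n",
"\n",
"def call_with_batch(fn, batch):\n",
"    # keep only the keys of the batch dict that appear in fn's signature\n",
"    params = inspect.signature(fn).parameters\n",
"    kwargs = {name: value for name, value in batch.items() if name in params}\n",
"    return fn(**kwargs)\n",
"\n",
"def train_step(x, y):\n",
"    # stand-in for a model method with the signature (x, y)\n",
"    return {\"loss\": ((x.sum(dim=-1) - y.float()) ** 2).mean()}\n",
"\n",
"batch = {\"x\": torch.randn(8, 10), \"y\": torch.randint(0, 10, (8,))}\n",
"print(call_with_batch(train_step, batch))  # -> {'loss': tensor(...)}\n",
"```"
]
},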
{
"cell_type": "markdown",
"id": "f5f1a6aa",
"metadata": {},
"source": [
"对于后者,首先要明确,在`Trainer`和`Evaluator`中,`metrics`的计算分为`update`和`get_metric`两步\n",
"\n",
"&emsp; &emsp; **`update`函数****针对一个`batch`的预测结果**,计算其累计的评价指标\n",
"\n",
"&emsp; &emsp; **`get_metric`函数****统计`update`函数累计的评价指标**,来计算最终的评价结果\n",
"\n",
"&emsp; 例如对于`Accuracy`来说,`update`函数会更新一个`batch`的正例数量`right_num`和负例数量`total_num`\n",
"\n",
"&emsp; &emsp; 而`get_metric`函数则会返回所有`batch`的评测值`right_num / total_num`\n",
"\n",
"&emsp; 在此基础上,**`fastNLP 0.8`要求`evaluate_dataloader`生成的每个`batch`传递给对应的`metric`**\n",
"\n",
"&emsp; &emsp; **以`{\"pred\": y_pred, \"target\": y_true}`的形式**,对应其`update`函数的函数签名\n",
"\n",
"<img src=\"./figures/T0-fig-parameter-matching.png\" width=\"75%\" height=\"75%\" align=\"center\"></img>"
]
},
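{
"cell_type": "markdown",
"id": "b7d02e4c",
"metadata": {},
"source": [
"As an illustration of the `update` / `get_metric` split described above (a sketch only, not the actual implementation of `fastNLP.core.metrics.Accuracy`), a minimal accuracy metric could look as follows\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"class SimpleAccuracy:\n",
"    def __init__(self):\n",
"        self.right_num = 0\n",
"        self.total_num = 0\n",
"\n",
"    def update(self, pred, target):\n",
"        # accumulate statistics over one batch of predictions\n",
"        self.right_num += (pred == target).sum().item()\n",
"        self.total_num += target.numel()\n",
"\n",
"    def get_metric(self):\n",
"        # aggregate the accumulated statistics into the final result\n",
"        return {\"acc\": self.right_num / max(self.total_num, 1)}\n",
"\n",
"metric = SimpleAccuracy()\n",
"metric.update(pred=torch.tensor([1, 2, 3]), target=torch.tensor([1, 2, 0]))\n",
"print(metric.get_metric())  # {'acc': 0.666...}\n",
"```"
]
},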
{
"cell_type": "markdown",
"id": "f62b7bb1",
"metadata": {},
"source": [
"### 2.3 一个实际案例argmax 模型\n",
"\n",
"下文将通过训练`argmax`模型,简单介绍如何`Trainer`模块的使用方式\n",
"\n",
"&emsp; 首先,使用`pytorch.nn.Module`定义`argmax`模型,目标是输入一组固定维度的向量,输出其中数值最大的数的索引"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5314482b",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class ArgMaxModel(nn.Module):\n",
" def __init__(self, num_labels, feature_dimension):\n",
" super(ArgMaxModel, self).__init__()\n",
" self.num_labels = num_labels\n",
"\n",
" self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n",
" self.ac1 = nn.ReLU()\n",
" self.linear2 = nn.Linear(in_features=10, out_features=10)\n",
" self.ac2 = nn.ReLU()\n",
" self.output = nn.Linear(in_features=10, out_features=num_labels)\n",
" self.loss_fn = nn.CrossEntropyLoss()\n",
"\n",
" def forward(self, x):\n",
" pred = self.ac1(self.linear1(x))\n",
" pred = self.ac2(self.linear2(pred))\n",
" pred = self.output(pred)\n",
" return pred\n",
"\n",
" def train_step(self, x, y):\n",
" pred = self(x)\n",
" return {\"loss\": self.loss_fn(pred, y)}\n",
"\n",
" def evaluate_step(self, x, y):\n",
" pred = self(x)\n",
" pred = torch.max(pred, dim=-1)[1]\n",
" return {\"pred\": pred, \"target\": y}"
]
},
{
"cell_type": "markdown",
"id": "71f3fa6b",
"metadata": {},
"source": [
"&emsp; 接着,使用`torch.utils.data.Dataset`定义`ArgMaxDataset`数据集\n",
"\n",
"&emsp; &emsp; 数据集包含三个参数:维度`feature_dimension`、数据量`data_num`和随机种子`seed`\n",
"\n",
"&emsp; &emsp; 数据及初始化是,自动生成指定维度的向量,并为每个向量标注出其中最大值的索引作为预测标签"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fe612e61",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"class ArgMaxDataset(Dataset):\n",
" def __init__(self, feature_dimension, data_num=1000, seed=0):\n",
" self.num_labels = feature_dimension\n",
" self.feature_dimension = feature_dimension\n",
" self.data_num = data_num\n",
" self.seed = seed\n",
"\n",
" g = torch.Generator()\n",
" g.manual_seed(1000)\n",
" self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n",
" self.y = torch.max(self.x, dim=-1)[1]\n",
"\n",
" def __len__(self):\n",
" return self.data_num\n",
"\n",
" def __getitem__(self, item):\n",
" return {\"x\": self.x[item], \"y\": self.y[item]}"
]
},
{
"cell_type": "markdown",
"id": "2cb96332",
"metadata": {},
"source": [
"&emsp; 然后,根据`ArgMaxModel`类初始化模型实例,保持输入维度`feature_dimension`和输出标签数量`num_labels`一致\n",
"\n",
"&emsp; &emsp; 再根据`ArgMaxDataset`类初始化两个数据集实例分别用来模型测试和模型评测数据量各1000笔"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "76172ef8",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"model = ArgMaxModel(num_labels=10, feature_dimension=10)\n",
"\n",
"train_dataset = ArgMaxDataset(feature_dimension=10, data_num=1000)\n",
"evaluate_dataset = ArgMaxDataset(feature_dimension=10, data_num=100)"
]
},
{
"cell_type": "markdown",
"id": "4e7d25ee",
"metadata": {},
"source": [
"&emsp; 此外,使用`torch.utils.data.DataLoader`初始化两个数据加载模块批量大小同为8分别用于训练和测评"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "363b5b09",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
"evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)"
]
},
{
"cell_type": "markdown",
"id": "c8d4443f",
"metadata": {},
"source": [
"&emsp; 最后,使用`torch.optim.SGD`初始化一个优化模块,基于随机梯度下降法"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dc28a2d9",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"from torch.optim import SGD\n",
"\n",
"optimizer = SGD(model.parameters(), lr=0.001)"
]
},
{
"cell_type": "markdown",
"id": "eb8ca6cf",
"metadata": {},
"source": [
"## 3. 使用 fastNLP 0.8 训练 argmax 模型\n",
"\n",
"### 3.1 trainer 外部初始化的 evaluator"
]
},
{
"cell_type": "markdown",
"id": "55145553",
"metadata": {},
"source": [
"通过从`fastNLP`库中导入`Trainer`类,初始化`trainer`实例,对模型进行训练\n",
"\n",
"&emsp; 需要导入预先定义好的模型`model`、对应的数据加载模块`train_dataloader`、优化模块`optimizer`\n",
"\n",
"&emsp; 通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n",
"\n",
"&emsp; &emsp; 但对于`\"auto\"`和`\"rich\"`格式在notebook中进度条在训练结束后会被丢弃\n",
"\n",
"&emsp; 通过`n_epochs`设定优化迭代轮数默认为20全部`Trainer`的全部变量与函数可以通过`dir(trainer)`查询"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b51b7a2d",
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from fastNLP import Trainer\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" driver=\"torch\",\n",
" device='cuda',\n",
" train_dataloader=train_dataloader,\n",
" optimizers=optimizer,\n",
" n_epochs=10, # 设定迭代轮数 \n",
" progress_bar=\"auto\" # 设定进度条格式\n",
")"
]
},
{
"cell_type": "markdown",
"id": "6e202d6e",
"metadata": {},
"source": [
"通过使用`Trainer`类的`run`函数,进行训练\n",
"\n",
"&emsp; 其中,可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止,默认全部\n",
"\n",
"&emsp; 此外,可以通过`inspect.getfullargspec(trainer.run)`查询`run`函数的全部参数列表"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ba047ead",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run()"
]
},
{
"cell_type": "markdown",
"id": "c16c5fa4",
"metadata": {},
"source": [
"通过从`fastNLP`库中导入`Evaluator`类,初始化`evaluator`实例,对模型进行评测\n",
"\n",
"&emsp; 需要导入预先定义好的模型`model`、对应的数据加载模块`evaluate_dataloader`\n",
"\n",
"&emsp; 需要注意的是评测方法`metrics`,设定为形如`{'acc': fastNLP.core.metrics.Accuracy()}`的字典\n",
"\n",
"&emsp; 类似地,也可以通过`progress_bar`限定进度条格式,默认为`\"auto\"`"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1c6b6b36",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"from fastNLP import Evaluator\n",
"from fastNLP.core.metrics import Accuracy\n",
"\n",
"evaluator = Evaluator(\n",
" model=model,\n",
" driver=trainer.driver, # 需要使用 trainer 已经启动的 driver\n",
" device=None,\n",
" dataloaders=evaluate_dataloader,\n",
" metrics={'acc': Accuracy()} # 需要严格使用此种形式的字典\n",
")"
]
},
{
"cell_type": "markdown",
"id": "8157bb9b",
"metadata": {},
"source": [
"通过使用`Evaluator`类的`run`函数,进行训练\n",
"\n",
"&emsp; 其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止,默认全部\n",
"\n",
"&emsp; 最终,输出形如`{'acc#acc': acc}`的字典在notebook中进度条在评测结束后会被丢弃"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f7cb0165",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.41</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'total#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'correct#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">41.0</span><span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.41\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m41.0\u001b[0m\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'acc#acc': 0.41, 'total#acc': 100.0, 'correct#acc': 41.0}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"evaluator.run()"
]
},
{
"cell_type": "markdown",
"id": "dd9f68fa",
"metadata": {},
"source": [
"### 3.2 trainer 内部初始化的 evaluator \n",
"\n",
"通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n",
"\n",
"&emsp; 通过`progress_bar`同时设定训练和评估进度条格式在notebook中在进度条训练结束后会被丢弃\n",
"\n",
"&emsp; **通过`evaluate_every`设定评估频率**,可以为负数、正数或者函数:\n",
"\n",
"&emsp; &emsp; **为负数时****表示每隔几个`epoch`评估一次****为正数时****则表示每隔几个`batch`评估一次**"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "183c7d19",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"trainer = Trainer(\n",
" model=model,\n",
" driver=trainer.driver, # 因为是在同个脚本中,这里的 driver 同样需要重用\n",
" train_dataloader=train_dataloader,\n",
" evaluate_dataloaders=evaluate_dataloader,\n",
" metrics={'acc': Accuracy()},\n",
" optimizers=optimizer,\n",
" n_epochs=10, \n",
" evaluate_every=-1, # 表示每个 epoch 的结束进行评估\n",
")"
]
},
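{
"cell_type": "markdown",
"id": "c5e81f9d",
"metadata": {},
"source": [
"Besides an integer, `evaluate_every` can also be a function; below is a minimal sketch, assuming the callable is invoked with the `trainer` instance during training and triggers an evaluation whenever it returns `True` (check the `fastNLP` documentation for the exact contract)\n",
"\n",
"```python\n",
"def make_evaluate_every(n):\n",
"    counter = {\"batches\": 0}\n",
"    def should_evaluate(trainer):\n",
"        # assumed contract: called with the trainer; returning True triggers an evaluation\n",
"        counter[\"batches\"] += 1\n",
"        return counter[\"batches\"] % n == 0\n",
"    return should_evaluate\n",
"\n",
"# trainer = Trainer(..., evaluate_every=make_evaluate_every(200), ...)\n",
"```"
]
},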
{
"cell_type": "markdown",
"id": "714cc404",
"metadata": {},
"source": [
"通过使用`Trainer`类的`run`函数,进行训练\n",
"\n",
"&emsp; 还可以通过参数`num_eval_sanity_batch`决定每次训练前运行多少个`evaluate_batch`进行评测默认为2\n",
"\n",
"&emsp; 之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此试探性评测"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2e4daa2c",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c4e9c619",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'acc#acc': 0.46, 'total#acc': 100.0, 'correct#acc': 46.0}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluator.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db784d5b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}