fastNLP/tutorials/fastnlp_tutorial_4.ipynb
2022-06-04 21:15:44 +08:00

2615 lines
111 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"id": "fdd7ff16",
"metadata": {},
"source": [
"# T4. fastNLP 中的预定义模型\n",
"\n",
"  1   fastNLP 中 modules 的介绍\n",
" \n",
"    1.1   modules 模块、models 模块 简介\n",
"\n",
"    1.2   示例一modules 实现 LSTM 分类\n",
"\n",
"  2   fastNLP 中 models 的介绍\n",
" \n",
"    2.1   示例一models 实现 CNN 分类\n",
"\n",
"    2.3   示例二models 实现 BiLSTM 标注"
]
},
{
"cell_type": "markdown",
"id": "d3d65d53",
"metadata": {},
"source": [
"## 1. fastNLP 中 modules 模块的介绍\n",
"\n",
"### 1.1 modules 模块、models 模块 简介\n",
"\n",
"在`fastNLP 0.8`中,**`modules.torch`路径下定义了一些基于`pytorch`实现的基础模块**\n",
"\n",
"    包括长短期记忆网络`LSTM`、条件随机场`CRF`、`transformer`的编解码器模块等,详见下表\n",
"\n",
"| <div align=\"center\">代码名称</div> | <div align=\"center\">简要介绍</div> | <div align=\"center\">代码路径</div> |\n",
"|:--|:--|:--|\n",
"| `LSTM` | 轻量封装`pytorch`的`LSTM` | `/modules/torch/encoder/lstm.py` |\n",
"| `Seq2SeqEncoder` | 序列变换编码器,基类 | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
"| `LSTMSeq2SeqEncoder` | 序列变换编码器,基于`LSTM` | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
"| `TransformerSeq2SeqEncoder` | 序列变换编码器,基于`transformer` | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
"| `StarTransformer` | `Star-Transformer`的编码器部分 | `/modules/torch/encoder/star_transformer.py` |\n",
"| `VarRNN` | 实现`Variational Dropout RNN` | `/modules/torch/encoder/variational_rnn.py` |\n",
"| `VarLSTM` | 实现`Variational Dropout LSTM` | `/modules/torch/encoder/variational_rnn.py` |\n",
"| `VarGRU` | 实现`Variational Dropout GRU` | `/modules/torch/encoder/variational_rnn.py` |\n",
"| `MLP` | 多层感知机模型 | `/modules/torch/decoder/mlp.py` |\n",
"| `ConditionalRandomField` | 条件随机场模型 | `/modules/torch/decoder/crf.py` |\n",
"| `Seq2SeqDecoder` | 序列变换解码器,基类 | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
"| `LSTMSeq2SeqDecoder` | 序列变换解码器,基于`LSTM` | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
"| `TransformerSeq2SeqDecoder` | 序列变换解码器,基于`transformer` | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
"| `SequenceGenerator` | 序列生成,封装`Seq2SeqDecoder` | `/models/torch/sequence_labeling.py` |\n",
"| `TimestepDropout` | 在每个`timestamp`上`dropout` | `/modules/torch/dropout.py` |"
]
},
{
"cell_type": "markdown",
"id": "89ffcf07",
"metadata": {},
"source": [
"&emsp; **`models.torch`路径下定义了一些基于`pytorch`、`modules`实现的预定义模型** \n",
"\n",
"&emsp; &emsp; 例如基于`CNN`的分类模型、基于`BiLSTM+CRF`的标注模型、基于[双仿射注意力机制](https://arxiv.org/pdf/1611.01734.pdf)的分析模型\n",
"\n",
"&emsp; &emsp; 基于`modules.torch`中的`LSTM`/`transformer`编/解码器模块的序列变换/生成模型,详见下表\n",
"\n",
"| <div align=\"center\">代码名称</div> | <div align=\"center\">简要介绍</div> | <div align=\"center\">代码路径</div> |\n",
"|:--|:--|:--|\n",
"| `BiaffineParser` | 句法分析模型,基于双仿射注意力 | `/models/torch/biaffine_parser.py` |\n",
"| `CNNText` | 文本分类模型,基于`CNN` | `/models/torch/cnn_text_classification.py` |\n",
"| `Seq2SeqModel` | 序列变换,基类`encoder+decoder` | `/models/torch/seq2seq_model.py` |\n",
"| `LSTMSeq2SeqModel` | 序列变换,基于`LSTM` | `/models/torch/seq2seq_model.py` |\n",
"| `TransformerSeq2SeqModel` | 序列变换,基于`transformer` | `/models/torch/seq2seq_model.py` |\n",
"| `SequenceGeneratorModel` | 封装`Seq2SeqModel`,结合`SequenceGenerator` | `/models/torch/seq2seq_generator.py` |\n",
"| `SeqLabeling` | 标注模型,基类`LSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n",
"| `BiLSTMCRF` | 标注模型,`BiLSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n",
"| `AdvSeqLabel` | 标注模型,`LN+BiLSTM*2+LN+FC+CRF` | `/models/torch/sequence_labeling.py` |"
]
},
{
"cell_type": "markdown",
"id": "61318354",
"metadata": {},
"source": [
"上述`fastNLP`模块,不仅**为入门级用户提供了简单易用的工具**,以解决各种`NLP`任务,或复现相关论文\n",
"\n",
"&emsp; 同时**也为专业研究人员提供了便捷可操作的接口**,封装部分代码的同时,也能指定参数修改细节\n",
"\n",
"&emsp; 在接下来的`tutorial`中,我们将通过`SST-2`分类和`CoNLL-2003`标注,展示相关模型使用\n",
"\n",
"注一:**`SST`****单句情感分类**数据集包含电影评论和对应情感极性1 对应正面情感0 对应负面情感\n",
"\n",
"&emsp; 数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n",
"\n",
"注二:**`CoNLL-2003`****文本语法标注**数据集,包含语句和对应的词性标签`pos_tags`(名动形数量代)\n",
"\n",
"&emsp; 语法结构标签`chunk_tags`(主谓宾定状补)、命名实体标签`ner_tags`(人名、组织名、地名、时间等)\n",
"\n",
"&emsp; 数据集包括三部分:训练集 14041 条,验证集 3250 条,测试集 3453 条,更多参考[原始论文](https://aclanthology.org/W03-0419.pdf)"
]
},
{
"cell_type": "markdown",
"id": "2a36bbe4",
"metadata": {},
"source": [
"### 1.2 示例一modules 实现 LSTM 分类\n",
"\n",
"\n",
"&emsp; 本示例使用`fastNLP 0.8`中预定义模型`modules`模块,基于`LSTM`模型,实现`SST-2`文本二分类任务\n",
"\n",
"数据使用方面:首先,**使用`datasets`模块中的`load_dataset`函数**,以如下形式,指定`SST-2`数据集加载\n",
"\n",
"&emsp; &emsp; 首次下载保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1aa5cf6d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b8bdfdc011d349e38a1aa2aff35b2482",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"sst2data = load_dataset('glue', 'sst2')"
]
},
{
"cell_type": "markdown",
"id": "c476abe7",
"metadata": {},
"source": [
"&emsp; 接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n",
"\n",
"&emsp; &emsp; **使用`apply_more`函数、`Vocabulary`模块的`from_/index_dataset`函数预处理数据**\n",
"\n",
"&emsp; &emsp; &emsp; 并结合`delete_field`函数删除字段调整格式,`split`函数划分测试集和验证集\n",
"\n",
"&emsp; &emsp; **仅保留`'words'`字段表示输入文本单词序号序列、`'target'`字段表示文本对应预测输出结果**\n",
"\n",
"&emsp; &emsp; &emsp; 两者**对应到`CNNText`中`train_step`函数和`evaluate_step`函数的签名/输入参数**"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "357ea748",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[38;5;2m[i 0604 16:19:46.727257 48 log.cc:351] Load log_sync: 1\u001b[m\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing: 0%| | 0/6000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"from fastNLP import DataSet\n",
"\n",
"dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
"\n",
"dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n",
" progress_bar=\"tqdm\")\n",
"dataset.delete_field('sentence')\n",
"dataset.delete_field('label')\n",
"dataset.delete_field('idx')\n",
"\n",
"from fastNLP import Vocabulary\n",
"\n",
"vocab = Vocabulary()\n",
"vocab.from_dataset(dataset, field_name='words')\n",
"vocab.index_dataset(dataset, field_name='words')\n",
"\n",
"train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
]
},
{
"cell_type": "markdown",
"id": "96380c67",
"metadata": {},
"source": [
"&emsp; 然后,使用`tutorial-3`中的知识,**通过`prepare_torch_dataloader`处理数据集得到`dataloader`**"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b9dd1273",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import prepare_torch_dataloader\n",
"\n",
"train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
]
},
{
"cell_type": "markdown",
"id": "eb75aaba",
"metadata": {},
"source": [
"模型使用方面,这里使用`Embedding`、`LSTM`、`MLP`等模块搭建模型,方法类似`pytorch`,结构如下所示\n",
"\n",
"```\n",
"ClsByModules(\n",
" (embedding): Embedding(\n",
" (embed): Embedding(8458, 100)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (lstm): LSTM(\n",
" (lstm): LSTM(100, 64, num_layers=2, batch_first=True, bidirectional=True)\n",
" )\n",
" (mlp): MLP(\n",
" (hiddens): ModuleList()\n",
" (output): Linear(in_features=128, out_features=2, bias=True)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" )\n",
" (loss_fn): CrossEntropyLoss()\n",
")\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0b25b25c",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"from fastNLP.modules.torch import LSTM, MLP\n",
"from fastNLP.embeddings.torch import Embedding\n",
"\n",
"\n",
"class ClsByModules(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
" nn.Module.__init__(self)\n",
"\n",
" self.embedding = Embedding((vocab_size, embedding_dim))\n",
" self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n",
" self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n",
" \n",
" self.loss_fn = nn.CrossEntropyLoss()\n",
"\n",
" def forward(self, words):\n",
" output = self.embedding(words)\n",
" output, (hidden, cell) = self.lstm(output)\n",
" output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n",
" return output\n",
" \n",
" def train_step(self, words, target):\n",
" pred = self(words)\n",
" return {\"loss\": self.loss_fn(pred, target)}\n",
"\n",
" def evaluate_step(self, words, target):\n",
" pred = self(words)\n",
" pred = torch.max(pred, dim=-1)[1]\n",
" return {\"pred\": pred, \"target\": target}"
]
},
{
"cell_type": "markdown",
"id": "4890de5a",
"metadata": {},
"source": [
"&emsp; 接着,初始化模型`model`实例,同时,使用`torch.optim.AdamW`初始化`optimizer`实例"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9dbbf50d",
"metadata": {},
"outputs": [],
"source": [
"model = ClsByModules(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n",
"\n",
"from torch.optim import AdamW\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-5)"
]
},
{
"cell_type": "markdown",
"id": "054538f5",
"metadata": {},
"source": [
"&emsp; 最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7a93432f",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Trainer, Accuracy\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" driver='torch',\n",
" device=0, # 'cuda'\n",
" n_epochs=10,\n",
" optimizers=optimizers,\n",
" train_dataloader=train_dataloader,\n",
" evaluate_dataloaders=evaluate_dataloader,\n",
" metrics={'acc': Accuracy()}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "31102e0f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[16:20:10] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#596\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">596</span></a>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[16:20:10]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=908530;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=864197;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
"</pre>\n"
],
"text/plain": [
"/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
"</pre>\n"
],
"text/plain": [
"/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" self.msg_id = ip.kernel._parent_header['header']['msg_id']\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.525</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">84.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m84.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.54375</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">87.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.54375\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m87.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.55</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">88.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.55\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m88.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.625</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.625\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.65</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">104.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.65\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m104.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.69375</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">111.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.69375\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m111.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.675</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">108.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.66875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">107.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.66875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m107.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.675</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">108.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.68125</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">109.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.68125\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m109.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run(num_eval_batch_per_dl=10)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8bc4bfb2",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'acc#acc': 0.712222, 'total#acc': 900.0, 'correct#acc': 641.0}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluator.run()"
]
},
{
"cell_type": "markdown",
"id": "07538876",
"metadata": {},
"source": [
"&emsp; 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1b52eafd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"383"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import gc\n",
"\n",
"del model\n",
"del trainer\n",
"del dataset\n",
"del sst2data\n",
"\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"id": "d9443213",
"metadata": {},
"source": [
"## 2. fastNLP 中 models 模块的介绍\n",
"\n",
"### 2.1 示例一models 实现 CNN 分类\n",
"\n",
"&emsp; 本示例使用`fastNLP 0.8`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n",
"\n",
"数据使用方面,此处沿用在上个示例中展示的`SST-2`数据集,数据加载过程相同且已经执行过了,因此简略\n",
"\n",
"模型使用方面,如上所述,这里使用**基于卷积神经网络`CNN`的预定义文本分类模型`CNNText`**,结构如下所示\n",
"\n",
"&emsp; 首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n",
"\n",
"&emsp; &emsp; **感受野为`1`、`3`、`5`的卷积算子变换至`30`维、`40`维、`50`维的卷积特征**,再将三者拼接\n",
"\n",
"&emsp; 最终再次通过`dropout`层、线性变换层,映射至二元的输出值,对应两个分类结果上的几率`logits`\n",
"\n",
"```\n",
"CNNText(\n",
" (embed): Embedding(\n",
" (embed): Embedding(5194, 100)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (conv_pool): ConvMaxpool(\n",
" (convs): ModuleList(\n",
" (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n",
" (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
" (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (fc): Linear(in_features=120, out_features=2, bias=True)\n",
")\n",
"```\n",
"\n",
"对应到代码上,**从`fastNLP.models.torch`路径下导入`CNNText`**,初始化`CNNText`和`optimizer`实例\n",
"\n",
"&emsp; 注意:初始化`CNNText`时,**二元组参数`embed`、分类数量`num_classes`是必须传入的**,其中\n",
"\n",
"&emsp; &emsp; **`embed`表示嵌入层的嵌入抽取矩阵大小**,因此第二个元素对应的是默认隐藏层维度 `100`维"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f6e76e2e",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.models.torch import CNNText\n",
"\n",
"model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
"\n",
"from torch.optim import AdamW\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-4)"
]
},
{
"cell_type": "markdown",
"id": "0cc5ca10",
"metadata": {},
"source": [
"&emsp; 最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "50a13ee5",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Trainer, Accuracy\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" driver='torch',\n",
" device=0, # 'cuda'\n",
" n_epochs=10,\n",
" optimizers=optimizers,\n",
" train_dataloader=train_dataloader,\n",
" evaluate_dataloaders=evaluate_dataloader,\n",
" metrics={'acc': Accuracy()}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "28903a7d",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[16:21:57] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#596\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">596</span></a>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[16:21:57]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=813103;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=271516;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.654444</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">900.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">589.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.654444\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m589.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.767778</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">900.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">691.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.767778\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m691.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.797778</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">900.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">718.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.797778\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m718.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.803333</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">900.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">723.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.803333\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m723.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.807778</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">900.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">727.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.807778\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m727.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.812222</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">900.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">731.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.812222\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m731.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.804444</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">900.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">724.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.804444\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m724.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.811111</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">900.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">730.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.811111</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">900.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">730.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.806667</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">900.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">726.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.806667\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m726.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f47a6a35",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'acc#acc': 0.806667, 'total#acc': 900.0, 'correct#acc': 726.0}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluator.run()"
]
},
{
"cell_type": "markdown",
"id": "5b5c0446",
"metadata": {},
"source": [
"&emsp; 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "e9e70f88",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"344"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import gc\n",
"\n",
"del model\n",
"del trainer\n",
"\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"id": "6aec2a19",
"metadata": {},
"source": [
"### 2.2 示例二models 实现 BiLSTM 标注\n",
"\n",
"&emsp; 通过两个示例一的对比可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n",
"\n",
"&emsp; &emsp; 针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n",
"\n",
"&emsp; 避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n",
"\n",
"模型使用方面,如上所述,这里使用**基于双向`LSTM`+条件随机场`CRF`的标注模型`BiLSTMCRF`**,结构如下所示\n",
"\n",
"&emsp; 其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层退学概率、`LSTM`层数可调\n",
"\n",
"```\n",
"BiLSTMCRF(\n",
" (embed): Embedding(7590, 100)\n",
" (lstm): LSTM(\n",
" (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (fc): Linear(in_features=200, out_features=9, bias=True)\n",
" (crf): ConditionalRandomField()\n",
")\n",
"```\n",
"\n",
"数据使用方面,此处仍然**使用`datasets`模块中的`load_dataset`函数**,以如下形式,加载`CoNLL-2003`数据集\n",
"\n",
"&emsp; 首次下载后会保存至`~.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "03e66686",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "593bc03ed5914953ab94268ff2f01710",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"ner2data = load_dataset('conll2003', 'conll2003')"
]
},
{
"cell_type": "markdown",
"id": "fc505631",
"metadata": {},
"source": [
"紧接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n",
"\n",
"&emsp; 完成数据集格式调整、文本序列化等操作;此处**需要`'words'`、`'seq_len'`、`'target'`三个字段**\n",
"\n",
"此外,**需要定义`NER`标签到标签序号的映射****词汇表`label_vocab`**),数据集中标签已经完成了序号映射\n",
"\n",
"&emsp; 所以需要人工定义**`9`个标签对应之前的`9`个分类目标**;数据集说明中规定,`'O'`表示其他标签\n",
"\n",
"&emsp; **后缀`'-PER'`、`'-ORG'`、`'-LOC'`、`'-MISC'`对应人名、组织名、地名、时间等其他命名**\n",
"\n",
"&emsp; **前缀`'B-'`表示起始标签、`'I-'`表示终止标签**;例如,`'B-PER'`表示人名实体的起始标签"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "1f88cad4",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing: 0%| | 0/4000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"from fastNLP import DataSet\n",
"\n",
"dataset = DataSet.from_pandas(ner2data['train'].to_pandas())[:4000]\n",
"\n",
"dataset.apply_more(lambda ins:{'words': ins['tokens'], 'seq_len': len(ins['tokens']), 'target': ins['ner_tags']}, \n",
" progress_bar=\"tqdm\")\n",
"dataset.delete_field('tokens')\n",
"dataset.delete_field('ner_tags')\n",
"dataset.delete_field('pos_tags')\n",
"dataset.delete_field('chunk_tags')\n",
"dataset.delete_field('id')\n",
"\n",
"from fastNLP import Vocabulary\n",
"\n",
"token_vocab = Vocabulary()\n",
"token_vocab.from_dataset(dataset, field_name='words')\n",
"token_vocab.index_dataset(dataset, field_name='words')\n",
"label_vocab = Vocabulary(padding=None, unknown=None)\n",
"label_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])\n",
"\n",
"train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
]
},
{
"cell_type": "markdown",
"id": "d9889427",
"metadata": {},
"source": [
"然后,同样使用`tutorial-3`中的知识,通过`prepare_torch_dataloader`处理数据集得到`dataloader`"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "7802a072",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import prepare_torch_dataloader\n",
"\n",
"train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
]
},
{
"cell_type": "markdown",
"id": "2bc7831b",
"metadata": {},
"source": [
"接着,**从`fastNLP.models.torch`路径下导入`BiLSTMCRF`**,初始化`BiLSTMCRF`实例和优化器\n",
"\n",
"&emsp; 注意:初始化`BiLSTMCRF`时,和`CNNText`相同,**参数`embed`、`num_classes`是必须传入的**\n",
"\n",
"&emsp; &emsp; 隐藏层维度`hidden_size`默认`100`维,调整`150`维;退学概率默认`0.1`,调整`0.2`"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "4e12c09f",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.models.torch import BiLSTMCRF\n",
"\n",
"model = BiLSTMCRF(embed=(len(token_vocab), 150), num_classes=len(label_vocab), \n",
" num_layers=1, hidden_size=150, dropout=0.2)\n",
"\n",
"from torch.optim import AdamW\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=1e-3)"
]
},
{
"cell_type": "markdown",
"id": "bf30608f",
"metadata": {},
"source": [
"最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练\n",
"\n",
"&emsp; **使用`SpanFPreRecMetric`作为`NER`的评价标准**,详细请参考接下来的`tutorial-5`\n",
"\n",
"&emsp; 同时,**初始化时需要添加`vocabulary`形式的标签与序号之间的映射`tag_vocab`**"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "cbd6c205",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Trainer, SpanFPreRecMetric\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" driver='torch',\n",
" device=0, # 'cuda'\n",
" n_epochs=10,\n",
" optimizers=optimizers,\n",
" train_dataloader=train_dataloader,\n",
" evaluate_dataloaders=evaluate_dataloader,\n",
" metrics={'F1': SpanFPreRecMetric(tag_vocab=label_vocab)}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "0f8eff34",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[16:23:41] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#596\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">596</span></a>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[16:23:41]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=565652;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=224849;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.169014</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.170732</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.167331</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.169014\u001b[0m,\n",
" \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.170732\u001b[0m,\n",
" \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.167331\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.361809</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.312139</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.430279</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.361809\u001b[0m,\n",
" \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.312139\u001b[0m,\n",
" \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.430279\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.525</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.475728</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.585657</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n",
" \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.475728\u001b[0m,\n",
" \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.585657\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.627306</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.584192</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.677291</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.627306\u001b[0m,\n",
" \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.584192\u001b[0m,\n",
" \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.677291\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.710937</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.697318</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.7251</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.710937\u001b[0m,\n",
" \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.697318\u001b[0m,\n",
" \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.7251\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.739563</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.738095</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.741036</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.739563\u001b[0m,\n",
" \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.738095\u001b[0m,\n",
" \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.748491</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.756098</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.741036</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.748491\u001b[0m,\n",
" \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.756098\u001b[0m,\n",
" \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.716763</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.69403</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.741036</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.716763\u001b[0m,\n",
" \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.69403\u001b[0m,\n",
" \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.768293</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.784232</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.752988</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.768293\u001b[0m,\n",
" \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.784232\u001b[0m,\n",
" \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.752988\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.757692</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.732342</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#F1\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.784861</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.757692\u001b[0m,\n",
" \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.732342\u001b[0m,\n",
" \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.784861\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run(num_eval_batch_per_dl=10)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "37871d6b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'f#F1': 0.766798, 'pre#F1': 0.741874, 'rec#F1': 0.793456}"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluator.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96bae094",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}