{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fdd7ff16",
   "metadata": {},
   "source": [
    "# T4. Predefined Models in fastNLP\n",
    "\n",
    "&emsp; 1 &ensp; Introduction to modules in fastNLP\n",
    "\n",
    "&emsp; &emsp; 1.1 &ensp; Overview of the modules and models packages\n",
    "\n",
    "&emsp; &emsp; 1.2 &ensp; Example 1: LSTM classification with modules\n",
    "\n",
    "&emsp; 2 &ensp; Introduction to models in fastNLP\n",
    "\n",
    "&emsp; &emsp; 2.1 &ensp; Example 1: CNN classification with models\n",
    "\n",
    "&emsp; &emsp; 2.2 &ensp; Example 2: BiLSTM tagging with models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3d65d53",
   "metadata": {},
   "source": [
    "## 1. Introduction to modules in fastNLP\n",
    "\n",
    "### 1.1 Overview of the modules and models packages\n",
    "\n",
    "In `fastNLP 0.8`, **the `modules.torch` package defines a set of basic building blocks implemented on top of `pytorch`**, including the long short-term memory network `LSTM`, the conditional random field `CRF`, and `transformer` encoder and decoder modules. The table below lists them.\n",
    "\n",
    "| <div align=\"center\">Name</div> | <div align=\"center\">Description</div> | <div align=\"center\">Path</div> |\n",
    "|:--|:--|:--|\n",
    "| `LSTM` | Lightweight wrapper around the `pytorch` `LSTM` | `/modules/torch/encoder/lstm.py` |\n",
    "| `Seq2SeqEncoder` | Sequence-to-sequence encoder, base class | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
    "| `LSTMSeq2SeqEncoder` | Sequence-to-sequence encoder based on `LSTM` | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
    "| `TransformerSeq2SeqEncoder` | Sequence-to-sequence encoder based on `transformer` | `/modules/torch/encoder/seq2seq_encoder.py` |\n",
    "| `StarTransformer` | Encoder part of `Star-Transformer` | `/modules/torch/encoder/star_transformer.py` |\n",
    "| `VarRNN` | Implements `Variational Dropout RNN` | `/modules/torch/encoder/variational_rnn.py` |\n",
    "| `VarLSTM` | Implements `Variational Dropout LSTM` | `/modules/torch/encoder/variational_rnn.py` |\n",
    "| `VarGRU` | Implements `Variational Dropout GRU` | `/modules/torch/encoder/variational_rnn.py` |\n",
    "| `ConditionalRandomField` | Conditional random field model | `/modules/torch/decoder/crf.py` |\n",
    "| `Seq2SeqDecoder` | Sequence-to-sequence decoder, base class | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
    "| `LSTMSeq2SeqDecoder` | Sequence-to-sequence decoder based on `LSTM` | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
    "| `TransformerSeq2SeqDecoder` | Sequence-to-sequence decoder based on `transformer` | `/modules/torch/decoder/seq2seq_decoder.py` |\n",
    "| `SequenceGenerator` | Sequence generation, wraps `Seq2SeqDecoder` | `/modules/torch/generator/seq2seq_generator.py` |\n",
    "| `TimestepDropout` | Applies `dropout` at every time step | `/modules/torch/dropout.py` |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89ffcf07",
   "metadata": {},
   "source": [
    "**The `models.torch` package defines a set of predefined models built on `pytorch` and `modules`**, for example a `CNN`-based classification model, a `BiLSTM+CRF` tagging model, a parsing model based on [biaffine attention](https://arxiv.org/pdf/1611.01734.pdf), and sequence-to-sequence and generation models built on the `LSTM`/`transformer` encoder and decoder modules in `modules.torch`. The table below lists them.\n",
    "\n",
    "| <div align=\"center\">Name</div> | <div align=\"center\">Description</div> | <div align=\"center\">Path</div> |\n",
    "|:--|:--|:--|\n",
    "| `BiaffineParser` | Syntactic parsing model based on biaffine attention | `/models/torch/biaffine_parser.py` |\n",
    "| `CNNText` | Text classification model based on `CNN` | `/models/torch/cnn_text_classification.py` |\n",
    "| `Seq2SeqModel` | Sequence-to-sequence model, base class, `encoder+decoder` | `/models/torch/seq2seq_model.py` |\n",
    "| `LSTMSeq2SeqModel` | Sequence-to-sequence model based on `LSTM` | `/models/torch/seq2seq_model.py` |\n",
    "| `TransformerSeq2SeqModel` | Sequence-to-sequence model based on `transformer` | `/models/torch/seq2seq_model.py` |\n",
    "| `SequenceGeneratorModel` | Wraps `Seq2SeqModel` and combines it with `SequenceGenerator` | `/models/torch/seq2seq_generator.py` |\n",
    "| `SeqLabeling` | Tagging model, base class, `LSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n",
    "| `BiLSTMCRF` | Tagging model, `BiLSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n",
    "| `AdvSeqLabel` | Tagging model, `LN+BiLSTM*2+LN+FC+CRF` | `/models/torch/sequence_labeling.py` |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61318354",
   "metadata": {},
   "source": [
    "These `fastNLP` modules **give beginners easy-to-use tools** for solving a range of `NLP` tasks or reproducing published results. They also **give researchers convenient, configurable interfaces**: much of the code is encapsulated, yet details can still be adjusted through parameters.\n",
    "\n",
    "In the rest of this `tutorial`, we demonstrate the models on `SST-2` classification and `CoNLL-2003` tagging.\n",
    "\n",
    "Note 1: **`SST-2`** is a **single-sentence sentiment classification** dataset of movie reviews with polarity labels, where 1 is positive and 0 is negative. It has three splits: 67,349 training, 872 validation, and 1,821 test examples. See the [download link](https://gluebenchmark.com/tasks) for details.\n",
    "\n",
    "Note 2: **`CoNLL-2003`** is a **sequence tagging** dataset. Each sentence comes with part-of-speech tags `pos_tags` (noun, verb, adjective, and so on), syntactic chunk tags `chunk_tags` (noun phrase, verb phrase, and so on), and named-entity tags `ner_tags` (person, organization, location, and miscellaneous). It has three splits: 14,041 training, 3,250 validation, and 3,453 test sentences. See the [original paper](https://aclanthology.org/W03-0419.pdf) for details."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a36bbe4",
   "metadata": {},
   "source": [
    "### 1.2 Example 1: LSTM classification with modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40e66b21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import sys\n",
    "# sys.path.append('..')\n",
    "\n",
    "# from fastNLP.io import SST2Pipe  # SST2Pipe is not available; this runs for a long time and then raises an error\n",
    "\n",
    "# databundle = SST2Pipe(tokenizer='raw').process_from_file()\n",
    "\n",
    "# dataset = databundle.get_dataset('train')[:6000]\n",
    "\n",
    "# dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']},\n",
    "#                    progress_bar=\"tqdm\")\n",
    "# dataset.delete_field('sentence')\n",
    "# dataset.delete_field('label')\n",
    "# dataset.delete_field('idx')\n",
    "\n",
    "# from fastNLP import Vocabulary\n",
    "\n",
    "# vocab = Vocabulary()\n",
    "# vocab.from_dataset(dataset, field_name='words')\n",
    "# vocab.index_dataset(dataset, field_name='words')\n",
    "\n",
    "# train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50960476",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "# train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
    "# evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b25b25c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import torch\n",
    "# import torch.nn as nn\n",
    "\n",
    "# from fastNLP.modules.torch import LSTM, MLP  # MLP is not available\n",
    "# from fastNLP import Embedding, CrossEntropyLoss\n",
    "\n",
    "\n",
    "# class ClsByModules(nn.Module):\n",
    "#     def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
    "#         nn.Module.__init__(self)\n",
    "\n",
    "#         self.embedding = Embedding((vocab_size, embedding_dim))\n",
    "#         self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n",
    "#         self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n",
    "\n",
    "#         self.loss_fn = CrossEntropyLoss()\n",
    "\n",
    "#     def forward(self, words):\n",
    "#         output = self.embedding(words)\n",
    "#         output, (hidden, cell) = self.lstm(output)\n",
    "#         output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n",
    "#         return output\n",
    "\n",
    "#     def train_step(self, words, target):\n",
    "#         pred = self(words)\n",
    "#         return {\"loss\": self.loss_fn(pred, target)}\n",
    "\n",
    "#     def evaluate_step(self, words, target):\n",
    "#         pred = self(words)\n",
    "#         pred = torch.max(pred, dim=-1)[1]\n",
    "#         return {\"pred\": pred, \"target\": target}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dbbf50d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model = ClsByModules(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n",
    "\n",
    "# from torch.optim import AdamW\n",
    "\n",
    "# optimizers = AdamW(params=model.parameters(), lr=5e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a93432f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from fastNLP import Trainer, Accuracy\n",
    "\n",
    "# trainer = Trainer(\n",
    "#     model=model,\n",
    "#     driver='torch',\n",
    "#     device=0,  # 'cuda'\n",
    "#     n_epochs=10,\n",
    "#     optimizers=optimizers,\n",
    "#     train_dataloader=train_dataloader,\n",
    "#     evaluate_dataloaders=evaluate_dataloader,\n",
    "#     metrics={'acc': Accuracy()}\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31102e0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# trainer.run(num_eval_batch_per_dl=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bc4bfb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# trainer.evaluator.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9443213",
   "metadata": {},
   "source": [
    "## 2. Introduction to models in fastNLP\n",
    "\n",
    "### 2.1 Example 1: CNN classification with models\n",
    "\n",
    "This example uses `CNNText`, one of the predefined models in `models` of `fastNLP 0.8`, for binary text classification on `SST-2`.\n",
    "\n",
    "On the model side, we use **`CNNText`, the predefined text classification model based on a convolutional neural network (`CNN`)**. Its structure is printed below. It starts with a built-in `100`-dimensional embedding layer and a `dropout` layer, followed by three one-dimensional convolutions: **kernels of size `1`, `3`, and `5` map the `100`-dimensional embeddings to `30`-, `40`-, and `50`-dimensional convolutional features**, which are max-pooled over the sequence and concatenated into a `120`-dimensional vector. A second `dropout` layer and a linear layer then map this vector to two outputs, the `logits` for the two classes.\n",
    "\n",
    "```\n",
    "CNNText(\n",
    "  (embed): Embedding(\n",
    "    (embed): Embedding(5194, 100)\n",
    "    (dropout): Dropout(p=0.0, inplace=False)\n",
    "  )\n",
    "  (conv_pool): ConvMaxpool(\n",
    "    (convs): ModuleList(\n",
    "      (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n",
    "      (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
    "      (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
    "    )\n",
    "  )\n",
    "  (dropout): Dropout(p=0.1, inplace=False)\n",
    "  (fc): Linear(in_features=120, out_features=2, bias=True)\n",
    ")\n",
    "```\n",
    "\n",
    "On the data side, we **use the `load_dataset` function from the `datasets` module** as shown in the next cell to load `SST-2` automatically. After the first download, it is cached under `~/.cache/huggingface/modules/datasets_modules/datasets/glue/`."
   ]
  },
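  {
   "cell_type": "markdown",
   "id": "e4b1c7a9",
   "metadata": {},
   "source": [
    "The dimension bookkeeping in the `CNNText` structure can be checked with a few lines of plain Python. This is a minimal sketch, not the `CNNText` implementation: `conv1d_out_len` is a hypothetical helper that applies the standard `Conv1d` output-length formula, `seq_len` is an arbitrary example value, and the channel counts, kernel sizes, and paddings are copied from the printed structure.\n",
    "\n",
    "```python\n",
    "def conv1d_out_len(seq_len, kernel_size, padding=0, stride=1):\n",
    "    # Standard Conv1d output-length formula (dilation 1)\n",
    "    return (seq_len + 2 * padding - kernel_size) // stride + 1\n",
    "\n",
    "# (out_channels, kernel_size, padding), copied from the printed ConvMaxpool\n",
    "convs = [(30, 1, 0), (40, 3, 1), (50, 5, 2)]\n",
    "\n",
    "seq_len = 12  # arbitrary example sentence length (an assumption)\n",
    "out_lens = [conv1d_out_len(seq_len, k, p) for _, k, p in convs]\n",
    "\n",
    "# Max-pooling over time keeps one value per channel, so each branch\n",
    "# contributes out_channels features regardless of sentence length\n",
    "feature_dim = sum(c for c, _, _ in convs)\n",
    "\n",
    "print(out_lens)     # [12, 12, 12]\n",
    "print(feature_dim)  # 120\n",
    "```\n",
    "\n",
    "With padding equal to half the kernel size, every branch preserves the sequence length, and the pooled features add up to the `120` inputs expected by the final linear layer."
   ]
  },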
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1aa5cf6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "sst2data = load_dataset('glue', 'sst2')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c476abe7",
   "metadata": {},
   "source": [
    "Next, using what was covered in `tutorial-1` and `tutorial-2`, convert the data into a `fastNLP` `DataSet`. **Preprocess it with `apply_more` and the `from_dataset`/`index_dataset` methods of `Vocabulary`**, use `delete_field` to drop the fields that are no longer needed, and use `split` to divide the data into training and validation sets.\n",
    "\n",
    "**Only two fields are kept: `'words'`, the sequence of word indices for the input text, and `'target'`, the label to predict.** Both **match the parameter names in the signatures of `train_step` and `evaluate_step` in `CNNText`**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "357ea748",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from fastNLP import DataSet\n",
    "\n",
    "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
    "\n",
    "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']},\n",
    "                   progress_bar=\"tqdm\")\n",
    "dataset.delete_field('sentence')\n",
    "dataset.delete_field('label')\n",
    "dataset.delete_field('idx')\n",
    "\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.from_dataset(dataset, field_name='words')\n",
    "vocab.index_dataset(dataset, field_name='words')\n",
    "\n",
    "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96380c67",
   "metadata": {},
   "source": [
    "Then, using what was covered in `tutorial-3`, **build the `dataloader`s from the datasets with `prepare_torch_dataloader`**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9dd1273",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
    "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96941b63",
   "metadata": {},
   "source": [
    "Next, **import `CNNText` from `fastNLP.models.torch`** and create the `CNNText` instance and the `optimizer` instance.\n",
    "\n",
    "Note: when initializing `CNNText`, **the 2-tuple `embed` and the number of classes `num_classes` are required**. **`embed` gives the shape of the embedding matrix**: the first element is the vocabulary size and the second is the embedding dimension, `100` here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6e76e2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP.models.torch import CNNText\n",
    "\n",
    "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
    "\n",
    "from torch.optim import AdamW\n",
    "\n",
    "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cc5ca10",
   "metadata": {},
   "source": [
    "Finally, use the `trainer` module to combine the `model`, `optimizer`, `dataloader`, and `metric` for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50a13ee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import Trainer, Accuracy\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    driver='torch',\n",
    "    device=0,  # 'cuda'\n",
    "    n_epochs=10,\n",
    "    optimizers=optimizers,\n",
    "    train_dataloader=train_dataloader,\n",
    "    evaluate_dataloaders=evaluate_dataloader,\n",
    "    metrics={'acc': Accuracy()}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28903a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f47a6a35",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.evaluator.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c811257",
   "metadata": {},
   "source": [
    "Note: the next cell deletes the variables that are no longer needed and calls the `gc` module to free memory, leaving room for training the next model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a2e2ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "del model\n",
    "del trainer\n",
    "del dataset\n",
    "del sst2data\n",
    "\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6aec2a19",
   "metadata": {},
   "source": [
    "### 2.2 Example 2: BiLSTM tagging with models\n",
    "\n",
    "Comparing the two versions of Example 1 (sections 1.2 and 2.1) shows that, because `models` encapsulates the model structure, using `models` is clearly more convenient, and the benefit grows as models become more complex. This example uses the `BiLSTMCRF` model from `models`, which spares us from writing the `CRF` and `Viterbi` code by hand, to perform named entity recognition (`NER`) on `CoNLL-2003`.\n",
    "\n",
    "On the model side, we use **`BiLSTMCRF`, the tagging model based on a bidirectional `LSTM` plus a conditional random field (`CRF`)**. Its structure is printed below. The hidden size defaults to `100`, so the bidirectional `LSTM` outputs `200` dimensions; the `dropout` probability and the number of `LSTM` layers are configurable.\n",
    "\n",
    "```\n",
    "BiLSTMCRF(\n",
    "  (embed): Embedding(7590, 100)\n",
    "  (lstm): LSTM(\n",
    "    (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n",
    "  )\n",
    "  (dropout): Dropout(p=0.1, inplace=False)\n",
    "  (fc): Linear(in_features=200, out_features=9, bias=True)\n",
    "  (crf): ConditionalRandomField()\n",
    ")\n",
    "```\n",
    "\n",
    "On the data side, we again **use the `load_dataset` function from the `datasets` module** as shown in the next cell to load `CoNLL-2003`. After the first download, it is cached under `~/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03e66686",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "ner2data = load_dataset('conll2003', 'conll2003')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc505631",
   "metadata": {},
   "source": [
    "Next, using what was covered in `tutorial-1` and `tutorial-2`, convert the data into a `fastNLP` `DataSet`, adjusting the fields and mapping tokens to indices. Here **three fields are needed: `'words'`, `'seq_len'`, and `'target'`**.\n",
    "\n",
    "In addition, **a mapping from `NER` labels to label indices is needed** (**the vocabulary `label_vocab`**). The labels in the dataset are already stored as indices, so we define the **`9` labels by hand, in the order that matches the `9` classes**. According to the dataset description, `'O'` marks tokens outside any entity; **the suffixes `'-PER'`, `'-ORG'`, `'-LOC'`, and `'-MISC'` stand for person, organization, location, and miscellaneous entities**; **the prefix `'B-'` marks the first token of an entity and `'I-'` marks a token inside it**. For example, `'B-PER'` marks the first token of a person name."
   ]
  },
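  {
   "cell_type": "markdown",
   "id": "b7d2f9c3",
   "metadata": {},
   "source": [
    "To make the `B-`/`I-`/`O` scheme concrete, here is a minimal sketch in plain Python that turns a sequence of label indices into entity spans. It is not how `SpanFPreRecMetric` is implemented: `bio_to_spans` and `LABELS` are hypothetical names introduced for illustration, and the label list simply repeats the order used for `label_vocab`.\n",
    "\n",
    "```python\n",
    "LABELS = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']\n",
    "\n",
    "def bio_to_spans(tag_ids):\n",
    "    # Return (entity_type, start, end) triples, with end exclusive.\n",
    "    # Stray 'I-' tags without a matching open entity are ignored.\n",
    "    spans, start, kind = [], None, None\n",
    "    # A trailing 'O' sentinel closes an entity that ends the sentence\n",
    "    for i, tag in enumerate([LABELS[t] for t in tag_ids] + ['O']):\n",
    "        inside = tag.startswith('I-') and kind == tag[2:]\n",
    "        if not inside and kind is not None:\n",
    "            spans.append((kind, start, i))  # close the open entity\n",
    "            start, kind = None, None\n",
    "        if tag.startswith('B-'):\n",
    "            start, kind = i, tag[2:]        # open a new entity\n",
    "    return spans\n",
    "\n",
    "# B-ORG O B-MISC O\n",
    "print(bio_to_spans([3, 0, 7, 0]))  # [('ORG', 0, 1), ('MISC', 2, 3)]\n",
    "# B-PER I-PER O B-LOC\n",
    "print(bio_to_spans([1, 2, 0, 5]))  # [('PER', 0, 2), ('LOC', 3, 4)]\n",
    "```\n",
    "\n",
    "A span-level metric scores whole entities such as `('PER', 0, 2)` rather than individual tokens, which is why the metric needs the label vocabulary to decode indices back into tags."
   ]
  },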
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f88cad4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from fastNLP import DataSet\n",
    "\n",
    "dataset = DataSet.from_pandas(ner2data['train'].to_pandas())[:4000]\n",
    "\n",
    "dataset.apply_more(lambda ins:{'words': ins['tokens'], 'seq_len': len(ins['tokens']), 'target': ins['ner_tags']},\n",
    "                   progress_bar=\"tqdm\")\n",
    "dataset.delete_field('tokens')\n",
    "dataset.delete_field('ner_tags')\n",
    "dataset.delete_field('pos_tags')\n",
    "dataset.delete_field('chunk_tags')\n",
    "dataset.delete_field('id')\n",
    "\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "token_vocab = Vocabulary()\n",
    "token_vocab.from_dataset(dataset, field_name='words')\n",
    "token_vocab.index_dataset(dataset, field_name='words')\n",
    "label_vocab = Vocabulary(padding=None, unknown=None)\n",
    "label_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])\n",
    "\n",
    "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9889427",
   "metadata": {},
   "source": [
    "Then, again using what was covered in `tutorial-3`, build the `dataloader`s from the datasets with `prepare_torch_dataloader`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7802a072",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
    "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bc7831b",
   "metadata": {},
   "source": [
    "Next, **import `BiLSTMCRF` from `fastNLP.models.torch`** and create the `BiLSTMCRF` instance and the optimizer.\n",
    "\n",
    "Note: as with `CNNText`, **`embed` and `num_classes` are required** when initializing `BiLSTMCRF`. The hidden size `hidden_size` defaults to `100` and is set to `150` here; the dropout probability defaults to `0.1` and is set to `0.2`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e12c09f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP.models.torch import BiLSTMCRF\n",
    "\n",
    "model = BiLSTMCRF(embed=(len(token_vocab), 150), num_classes=len(label_vocab),\n",
    "                  num_layers=1, hidden_size=150, dropout=0.2)\n",
    "\n",
    "from torch.optim import AdamW\n",
    "\n",
    "optimizers = AdamW(params=model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf30608f",
   "metadata": {},
   "source": [
    "Finally, use the `trainer` module to combine the `model`, `optimizer`, `dataloader`, and `metric` for training.\n",
    "\n",
    "**`SpanFPreRecMetric` serves as the evaluation metric for `NER`**; see the upcoming `tutorial-5` for details. **It must be initialized with `tag_vocab`, the mapping between labels and indices given as a `Vocabulary`**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbd6c205",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import Trainer, SpanFPreRecMetric\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    driver='torch',\n",
    "    device=0,  # 'cuda'\n",
    "    n_epochs=10,\n",
    "    optimizers=optimizers,\n",
    "    train_dataloader=train_dataloader,\n",
    "    evaluate_dataloaders=evaluate_dataloader,\n",
    "    metrics={'F1': SpanFPreRecMetric(tag_vocab=label_vocab)}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f8eff34",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.run(num_eval_batch_per_dl=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37871d6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.evaluator.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96bae094",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}