{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"  从这篇开始,我们将开启**`fastNLP v0.8 tutorial`的`example`系列**,在接下来的\n",
|
||
"\n",
|
||
"  每篇`tutorial`里,我们将会介绍`fastNLP v0.8`在自然语言处理任务上的应用实例"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# E1. 使用 Bert + fine-tuning 完成 SST-2 分类\n",
|
||
"\n",
|
||
"  1   基础介绍:`GLUE`通用语言理解评估、`SST-2`文本情感二分类数据集 \n",
|
||
"\n",
|
||
"  2   准备工作:加载`tokenizer`、预处理`dataset`、`dataloader`使用\n",
|
||
"\n",
|
||
"  3   模型训练:加载`distilbert-base`、`fastNLP`参数匹配、`fine-tuning`"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"4.18.0\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import torch\n",
|
||
"import torch.nn as nn\n",
|
||
"from torch.optim import AdamW\n",
|
||
"from torch.utils.data import DataLoader, Dataset\n",
|
||
"\n",
|
||
"import transformers\n",
|
||
"from transformers import AutoTokenizer\n",
|
||
"from transformers import AutoModelForSequenceClassification\n",
|
||
"\n",
|
||
"import sys\n",
|
||
"sys.path.append('..')\n",
|
||
"\n",
|
||
"import fastNLP\n",
|
||
"from fastNLP import Trainer\n",
|
||
"from fastNLP import Accuracy\n",
|
||
"\n",
|
||
"print(transformers.__version__)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 1. 基础介绍:GLUE 通用语言理解评估、SST-2 文本情感二分类数据集\n",
|
||
"\n",
|
||
"  本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`fine-tuning`方式\n",
|
||
"\n",
|
||
"    调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST-2`\n",
|
||
"\n",
|
||
"**`GLUE`**,**全称`General Language Understanding Evaluation`**,**通用语言理解评估**,\n",
|
||
"\n",
|
||
"  包含9个数据集,各语料的语言均为英语,涉及多个自然语言理解`NLU`任务,包括\n",
|
||
"\n",
|
||
"    **`CoLA`**,文本分类任务,预测单句语法正误分类;**`SST-2`**,文本分类任务,预测单句情感二分类\n",
|
||
"\n",
|
||
"    **`MRPC`**,句对分类任务,预测句对语义一致性;**`STS-B`**,相似度打分任务,预测句对语义相似度回归\n",
|
||
"\n",
|
||
"    **`QQP`**,句对分类任务,预测问题对语义一致性;**`MNLI`**,文本推理任务,预测句对蕴含/矛盾/中立预测\n",
|
||
"\n",
|
||
"    **`QNLI`/`RTE`/`WNLI`**,文本推理,预测是否蕴含二分类(其中,`QNLI`从`SQuAD`转化而来\n",
|
||
"\n",
|
||
"  诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n",
|
||
"\n",
|
||
"    此处,我们使用`SST-2`来训练`bert`,实现文本分类,其他任务描述见下图"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n",
|
||
"\n",
|
||
"task = 'sst2'"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"<img src=\"./figures/E1-fig-glue-benchmark.png\" width=\"70%\" height=\"70%\" align=\"center\"></img>\n",
|
||
"\n",
|
||
"**`SST`**,**全称`Stanford Sentiment Treebank`**,**斯坦福情感树库**,**单句情感分类**数据集\n",
|
||
"\n",
|
||
"  包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 对应`negative` 负面情感\n",
|
||
"\n",
|
||
"  数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n",
|
||
"\n",
|
||
"对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST-2`数据集,自动加载\n",
|
||
"\n",
|
||
"  首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"scrolled": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "c5915debacf9443986b5b3b34870b303",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
" 0%| | 0/3 [00:00<?, ?it/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from datasets import load_dataset\n",
|
||
"\n",
|
||
"dataset = load_dataset('glue', task)"
|
||
]
|
||
},
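{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before preprocessing, it can help to look at what `load_dataset` returned; the following is a minimal sketch (not executed in this tutorial), assuming the cell above has already defined `dataset`:\n",
"\n",
"```python\n",
"# Inspect the splits and one raw example of the loaded GLUE/SST-2 data.\n",
"print(dataset)                    # DatasetDict with 'train', 'validation' and 'test' splits\n",
"print(dataset['train'].features)  # fields: 'sentence', 'label', 'idx'\n",
"print(dataset['train'][0])        # a single raw example\n",
"```"
]
},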
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"  加载之后,根据`GLUE`中`SST-2`数据集的格式,尝试打印部分数据,检查加载结果"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Sentence: hide new secretions from the parental units \n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"task_to_keys = {\n",
|
||
" 'cola': ('sentence', None),\n",
|
||
" 'mnli': ('premise', 'hypothesis'),\n",
|
||
" 'mnli': ('premise', 'hypothesis'),\n",
|
||
" 'mrpc': ('sentence1', 'sentence2'),\n",
|
||
" 'qnli': ('question', 'sentence'),\n",
|
||
" 'qqp': ('question1', 'question2'),\n",
|
||
" 'rte': ('sentence1', 'sentence2'),\n",
|
||
" 'sst2': ('sentence', None),\n",
|
||
" 'stsb': ('sentence1', 'sentence2'),\n",
|
||
" 'wnli': ('sentence1', 'sentence2'),\n",
|
||
"}\n",
|
||
"\n",
|
||
"sentence1_key, sentence2_key = task_to_keys[task]\n",
|
||
"\n",
|
||
"if sentence2_key is None:\n",
|
||
" print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n",
|
||
"else:\n",
|
||
" print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n",
|
||
" print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 2. 准备工作:加载 tokenizer、预处理 dataset、dataloader 使用\n",
|
||
"\n",
|
||
"  接下来进入模型训练的准备工作,分别需要使用`tokenizer`模块对数据集进行分词与标注\n",
|
||
"\n",
|
||
"    定义`SeqClsDataset`对应`dataloader`模块用来实现数据集在训练/测试时的加载\n",
|
||
"\n",
|
||
"此处的`tokenizer`和`SequenceClassificationModel`都是基于**`distilbert-base-uncased`模型**\n",
|
||
"\n",
|
||
"  即使用较小的、不区分大小写的数据集,**对`bert-base`进行知识蒸馏后的版本**,结构上\n",
|
||
"\n",
|
||
"  包含**1个编码层**、**6个自注意力层**,**参数量`66M`**,详解见本篇末尾,更多请参考[DistilBert论文](https://arxiv.org/pdf/1910.01108.pdf)\n",
|
||
"\n",
|
||
"首先,通过从`transformers`库中导入**`AutoTokenizer`模块**,**使用`from_pretrained`函数初始化**\n",
|
||
"\n",
|
||
"  此处的`use_fast`表示是否使用`tokenizer`的快速版本;尝试序列化示例数据,检查加载结果\n",
|
||
"\n",
|
||
"  需要注意的是,处理后返回的两个键值,**`'input_ids'`**表示原始文本对应的词素编号序列\n",
|
||
"\n",
|
||
"    **`'attention_mask'`**表示自注意力运算时的掩模(标上`0`的部分对应`padding`的内容"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"model_checkpoint = 'distilbert-base-uncased'\n",
|
||
"\n",
|
||
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
|
||
"\n",
|
||
"print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))"
|
||
]
|
||
},
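{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the `attention_mask` behaviour described above (positions marked `0` correspond to `padding`), one can tokenize a small batch with padding enabled; this is a minimal sketch, not executed in this tutorial:\n",
"\n",
"```python\n",
"# Tokenize two sentences of different lengths with padding: the shorter one is\n",
"# padded, and its trailing attention_mask positions become 0.\n",
"batch = tokenizer([\"A short sentence.\", \"A noticeably longer sentence with many more tokens in it.\"],\n",
"                  padding=True, truncation=True)\n",
"print(batch['input_ids'])\n",
"print(batch['attention_mask'])  # trailing zeros mark the padded positions\n",
"```"
]
},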
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"接着,定义预处理函数,**通过`dataset.map`方法**,**将数据集中的文本**,**替换为词素编号序列**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ca1fbe5e8eb059f3.arrow\n",
|
||
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-03661263fbf302f5.arrow\n",
|
||
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fbe8e7a4e4f18f45.arrow\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"def preprocess_function(examples):\n",
|
||
" if sentence2_key is None:\n",
|
||
" return tokenizer(examples[sentence1_key], truncation=True)\n",
|
||
" return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n",
|
||
"\n",
|
||
"encoded_dataset = dataset.map(preprocess_function, batched=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"然后,通过继承`torch`中的`Dataset`类,定义`SeqClsDataset`类,需要注意的是\n",
|
||
"\n",
|
||
"  其中,**`__getitem__`函数各返回值引用的键值**,**必须和原始数据集中的属性对应**\n",
|
||
"\n",
|
||
"  例如,`'label'`是`SST-2`数据集中原有的内容(包括`'sentence'`和`'label'`\n",
|
||
"\n",
|
||
"    `'input_ids'`和`'attention_mask'`则是`tokenizer`处理后添加的字段"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class SeqClsDataset(Dataset):\n",
|
||
" def __init__(self, dataset):\n",
|
||
" Dataset.__init__(self)\n",
|
||
" self.dataset = dataset\n",
|
||
"\n",
|
||
" def __len__(self):\n",
|
||
" return len(self.dataset)\n",
|
||
"\n",
|
||
" def __getitem__(self, item):\n",
|
||
" item = self.dataset[item]\n",
|
||
" return item['input_ids'], item['attention_mask'], [item['label']] "
|
||
]
|
||
},
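{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of `__getitem__`, one can wrap the tokenized training split and look at a single item; this is a minimal sketch, not executed in this tutorial:\n",
"\n",
"```python\n",
"# Each item is (input_ids, attention_mask, [label]), matching __getitem__ above.\n",
"tmp_dataset = SeqClsDataset(encoded_dataset['train'])\n",
"ids, mask, label = tmp_dataset[0]\n",
"print(len(tmp_dataset), len(ids), len(mask), label)\n",
"```"
]
},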
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"再然后,**定义校对函数`collate_fn`对齐同个`batch`内的每笔数据**,需要注意的是该函数的\n",
|
||
"\n",
|
||
"  **返回值必须是字典**,**键值必须同待训练模型的`train_step`和`evaluate_step`函数的参数**\n",
|
||
"\n",
|
||
"  **相对应**;这也就是在`tutorial-0`中便被强调的,`fastNLP v0.8`的第一条**参数匹配**机制"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def collate_fn(batch):\n",
|
||
" input_ids, atten_mask, labels = [], [], []\n",
|
||
" max_length = [0] * 3\n",
|
||
" for each_item in batch:\n",
|
||
" input_ids.append(each_item[0])\n",
|
||
" max_length[0] = max(max_length[0], len(each_item[0]))\n",
|
||
" atten_mask.append(each_item[1])\n",
|
||
" max_length[1] = max(max_length[1], len(each_item[1]))\n",
|
||
" labels.append(each_item[2])\n",
|
||
" max_length[2] = max(max_length[2], len(each_item[2]))\n",
|
||
"\n",
|
||
" for i in range(3):\n",
|
||
" each = (input_ids, atten_mask, labels)[i]\n",
|
||
" for item in each:\n",
|
||
" item.extend([0] * (max_length[i] - len(item)))\n",
|
||
" return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n",
|
||
" 'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n",
|
||
" 'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}"
|
||
]
|
||
},
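{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the padding and the key names concrete, here is a minimal sketch (not executed in this tutorial) that applies `collate_fn` to a small hand-built batch:\n",
"\n",
"```python\n",
"# Sequences of different lengths are zero-padded to the longest one, and the result\n",
"# is a dict whose keys match the parameters of train_step / evaluate_step.\n",
"toy_batch = [([101, 7592, 102], [1, 1, 1], [0]),\n",
"             ([101, 7592, 999, 102], [1, 1, 1, 1], [1])]\n",
"out = collate_fn(toy_batch)\n",
"print({k: v.shape for k, v in out.items()})  # input_ids/attention_mask: (2, 4), labels: (2,)\n",
"```"
]
},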
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"最后,分别对`tokenizer`处理过的训练集数据、验证集数据,进行预处理和批量划分"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"dataset_train = SeqClsDataset(encoded_dataset['train'])\n",
|
||
"dataloader_train = DataLoader(dataset=dataset_train, \n",
|
||
" batch_size=32, shuffle=True, collate_fn=collate_fn)\n",
|
||
"dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n",
|
||
"dataloader_valid = DataLoader(dataset=dataset_valid, \n",
|
||
" batch_size=32, shuffle=False, collate_fn=collate_fn)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### 3. 模型训练:加载 distilbert-base、fastNLP 参数匹配、fine-tuning\n",
|
||
"\n",
|
||
"  最后就是模型训练的,分别需要使用`distilbert-base-uncased`搭建分类模型\n",
|
||
"\n",
|
||
"    初始化优化器`optimizer`、训练模块`trainer`,通过`run`函数完成训练\n",
|
||
"\n",
|
||
"此处使用的`nn.Module`模块搭建模型,与`tokenizer`类似,通过从`transformers`库中\n",
|
||
"\n",
|
||
"  导入`AutoModelForSequenceClassification`模块,基于`distilbert-base-uncased`模型初始\n",
|
||
"\n",
|
||
"需要注意的是**`AutoModelForSequenceClassification`模块的输入参数和输出结构**\n",
|
||
"\n",
|
||
"  一方面,可以**通过输入标签值`labels`**,**使用模块内的损失函数计算损失`loss`**\n",
|
||
"\n",
|
||
"    并且可以选择输入是词素编号序列`input_ids`,还是词素嵌入序列`inputs_embeds`\n",
|
||
"\n",
|
||
"  另方面,该模块不会直接输出预测结果,而是会**输出各预测分类上的几率`logits`**\n",
|
||
"\n",
|
||
"    基于上述描述,此处完成了中`train_step`和`evaluate_step`函数的定义\n",
|
||
"\n",
|
||
"    同样需要注意,函数的返回值体现了`fastNLP v0.8`的第二条**参数匹配**机制"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class SeqClsModel(nn.Module):\n",
|
||
" def __init__(self, num_labels, model_checkpoint):\n",
|
||
" nn.Module.__init__(self)\n",
|
||
" self.num_labels = num_labels\n",
|
||
" self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n",
|
||
" num_labels=num_labels)\n",
|
||
"\n",
|
||
" def forward(self, input_ids, attention_mask, labels=None):\n",
|
||
" output = self.back_bone(input_ids=input_ids, \n",
|
||
" attention_mask=attention_mask, labels=labels)\n",
|
||
" return output\n",
|
||
"\n",
|
||
" def train_step(self, input_ids, attention_mask, labels):\n",
|
||
" loss = self(input_ids, attention_mask, labels).loss\n",
|
||
" return {'loss': loss}\n",
|
||
"\n",
|
||
" def evaluate_step(self, input_ids, attention_mask, labels):\n",
|
||
" pred = self(input_ids, attention_mask, labels).logits\n",
|
||
" pred = torch.max(pred, dim=-1)[1]\n",
|
||
" return {'pred': pred, 'target': labels}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias']\n",
|
||
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
||
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
||
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n",
|
||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"num_labels = 3 if task == 'mnli' else 1 if task == 'stsb' else 2\n",
|
||
"\n",
|
||
"model = SeqClsModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
|
||
"\n",
|
||
"optimizers = AdamW(params=model.parameters(), lr=5e-5)"
|
||
]
|
||
},
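{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before handing the model to the `Trainer`, the parameter matching can be checked by hand; this is a minimal sketch (not executed in this tutorial), assuming `model` and `dataloader_valid` from the cells above and running on CPU:\n",
"\n",
"```python\n",
"# Feed one batch through train_step / evaluate_step and look at the returned keys.\n",
"batch = next(iter(dataloader_valid))     # dict with 'input_ids', 'attention_mask', 'labels'\n",
"with torch.no_grad():\n",
"    print(model.train_step(**batch))     # {'loss': tensor(...)}\n",
"    print(model.evaluate_step(**batch))  # {'pred': tensor([...]), 'target': tensor([...])}\n",
"```"
]
},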
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"然后,使用之前完成的`dataloader_train`和`dataloader_valid`,定义训练模块`trainer`"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"trainer = Trainer(\n",
|
||
" model=model,\n",
|
||
" driver='torch',\n",
|
||
" device=0, # 'cuda'\n",
|
||
" n_epochs=10,\n",
|
||
" optimizers=optimizers,\n",
|
||
" train_dataloader=dataloader_train,\n",
|
||
" evaluate_dataloaders=dataloader_valid,\n",
|
||
" metrics={'acc': Accuracy()}\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"最后,使用`trainer.run`方法,训练模型,`n_epochs`参数中已经指定需要迭代`10`轮\n",
|
||
"\n",
|
||
"  `num_eval_batch_per_dl`参数则指定每次只对验证集中的`10`个`batch`进行评估"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[09:12:45] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#592\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">592</span></a>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[2;36m[09:12:45]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=408427;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=303634;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
"Output()"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
||
],
|
||
"text/plain": []
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
"Output()"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.884375</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">283.0</span>\n",
|
||
"<span style=\"font-weight: bold\">}</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m{\u001b[0m\n",
|
||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n",
|
||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
|
||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n",
|
||
"\u001b[1m}\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.878125</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">281.0</span>\n",
|
||
"<span style=\"font-weight: bold\">}</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m{\u001b[0m\n",
|
||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n",
|
||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
|
||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n",
|
||
"\u001b[1m}\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.884375</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">283.0</span>\n",
|
||
"<span style=\"font-weight: bold\">}</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m{\u001b[0m\n",
|
||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n",
|
||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
|
||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n",
|
||
"\u001b[1m}\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">288.0</span>\n",
|
||
"<span style=\"font-weight: bold\">}</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m{\u001b[0m\n",
|
||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.9\u001b[0m,\n",
|
||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
|
||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m288.0\u001b[0m\n",
|
||
"\u001b[1m}\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8875</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">284.0</span>\n",
|
||
"<span style=\"font-weight: bold\">}</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m{\u001b[0m\n",
|
||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n",
|
||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
|
||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n",
|
||
"\u001b[1m}\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.88125</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">282.0</span>\n",
|
||
"<span style=\"font-weight: bold\">}</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m{\u001b[0m\n",
|
||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.88125\u001b[0m,\n",
|
||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
|
||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m282.0\u001b[0m\n",
|
||
"\u001b[1m}\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.875</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">280.0</span>\n",
|
||
"<span style=\"font-weight: bold\">}</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m{\u001b[0m\n",
|
||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n",
|
||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
|
||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n",
|
||
"\u001b[1m}\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.865625</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">277.0</span>\n",
|
||
"<span style=\"font-weight: bold\">}</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m{\u001b[0m\n",
|
||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.865625\u001b[0m,\n",
|
||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
|
||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m277.0\u001b[0m\n",
|
||
"\u001b[1m}\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.884375</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">283.0</span>\n",
|
||
"<span style=\"font-weight: bold\">}</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m{\u001b[0m\n",
|
||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n",
|
||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
|
||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n",
|
||
"\u001b[1m}\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.878125</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
|
||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">281.0</span>\n",
|
||
"<span style=\"font-weight: bold\">}</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m{\u001b[0m\n",
|
||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n",
|
||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
|
||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n",
|
||
"\u001b[1m}\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
||
],
|
||
"text/plain": []
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"trainer.run(num_eval_batch_per_dl=10)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
"Output()"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
|
||
],
|
||
"text/plain": []
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"{'acc#acc': 0.884174, 'total#acc': 872.0, 'correct#acc': 771.0}"
|
||
]
|
||
},
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"trainer.evaluator.run()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"#### 附:`DistilBertForSequenceClassification`模块结构\n",
|
||
"\n",
|
||
"```\n",
|
||
"<bound method DistilBertForSequenceClassification.forward of DistilBertForSequenceClassification(\n",
|
||
" (distilbert): DistilBertModel(\n",
|
||
" (embeddings): Embeddings(\n",
|
||
" (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
|
||
" (position_embeddings): Embedding(512, 768)\n",
|
||
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" )\n",
|
||
" (transformer): Transformer(\n",
|
||
" (layer): ModuleList(\n",
|
||
" (0): TransformerBlock(\n",
|
||
" (attention): MultiHeadSelfAttention(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" )\n",
|
||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" (ffn): FFN(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||
" (activation): GELUActivation()\n",
|
||
" )\n",
|
||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" )\n",
|
||
" (1): TransformerBlock(\n",
|
||
" (attention): MultiHeadSelfAttention(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" )\n",
|
||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" (ffn): FFN(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||
" (activation): GELUActivation()\n",
|
||
" )\n",
|
||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" )\n",
|
||
" (2): TransformerBlock(\n",
|
||
" (attention): MultiHeadSelfAttention(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" )\n",
|
||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" (ffn): FFN(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||
" (activation): GELUActivation()\n",
|
||
" )\n",
|
||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" )\n",
|
||
" (3): TransformerBlock(\n",
|
||
" (attention): MultiHeadSelfAttention(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" )\n",
|
||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" (ffn): FFN(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||
" (activation): GELUActivation()\n",
|
||
" )\n",
|
||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" )\n",
|
||
" (4): TransformerBlock(\n",
|
||
" (attention): MultiHeadSelfAttention(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" )\n",
|
||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" (ffn): FFN(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||
" (activation): GELUActivation()\n",
|
||
" )\n",
|
||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" )\n",
|
||
" (5): TransformerBlock(\n",
|
||
" (attention): MultiHeadSelfAttention(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" )\n",
|
||
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" (ffn): FFN(\n",
|
||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
||
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
||
" (activation): GELUActivation()\n",
|
||
" )\n",
|
||
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n",
|
||
" (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
|
||
" (dropout): Dropout(p=0.2, inplace=False)\n",
|
||
")>\n",
|
||
"```"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.7.13"
|
||
},
|
||
"pycharm": {
|
||
"stem_cell": {
|
||
"cell_type": "raw",
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"source": []
|
||
}
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 1
|
||
}
|