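{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ " Starting with this notebook, we begin the **`example` series of the `fastNLP v0.8 tutorial`**. In each of the following\n", "\n", " `tutorial`s, we will present an application of `fastNLP v0.8` to a natural language processing task" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# E1. SST-2 classification with BERT + fine-tuning\n", "\n", " 1 Background: the `GLUE` General Language Understanding Evaluation benchmark and the `SST-2` binary sentiment classification dataset\n", "\n", " 2 Preparation: loading the `tokenizer`, preprocessing the `dataset`, using the `dataloader`\n", "\n", " 3 Model training: loading `distilbert-base`, `fastNLP` parameter matching, `fine-tuning`" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "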
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "4.18.0\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", "import transformers\n", "from transformers import AutoTokenizer\n", "from transformers import AutoModelForSequenceClassification\n", "\n", "import sys\n", "sys.path.append('..')\n", "\n", "import fastNLP\n", "from fastNLP import Trainer\n", "from fastNLP import Accuracy\n", "\n", "print(transformers.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Background: the GLUE benchmark and the SST-2 binary sentiment classification dataset\n", "\n", " This example uses the `SST-2` dataset from the `GLUE` benchmark to fine-tune\n", "\n", " a `distilbert-base` classification model. We start with a brief introduction to `GLUE` and `SST-2`\n", "\n", "**`GLUE`**, short for **`General Language Understanding Evaluation`**, is a benchmark that\n", "\n", " contains 9 English datasets covering several natural language understanding (`NLU`) tasks:\n", "\n", " **`CoLA`**, single-sentence classification: predict whether a sentence is grammatical; **`SST-2`**, single-sentence classification: predict binary sentiment\n", "\n", " **`MRPC`**, sentence-pair classification: predict whether two sentences are semantically equivalent; **`STS-B`**, similarity scoring: regress the semantic similarity of a sentence pair\n", "\n", " **`QQP`**, sentence-pair classification: predict whether two questions are semantically equivalent; **`MNLI`**, natural language inference: classify a sentence pair as entailment/contradiction/neutral\n", "\n", " **`QNLI`/`RTE`/`WNLI`**, natural language inference: binary entailment prediction (`QNLI` is derived from `SQuAD`)\n", "\n", " Classic models such as `BERT` and `T5` report results on this benchmark; see the [GLUE paper](https://arxiv.org/pdf/1804.07461v3.pdf) for more\n", "\n", " Here we use `SST-2` to fine-tune a `bert`-style model for text classification; the other tasks are described in the GLUE paper" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n", "\n", "task = 'sst2'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "**`SST`**, short for **`Stanford Sentiment Treebank`**, is a **single-sentence sentiment classification** dataset\n", "\n", " It contains sentences from movie reviews with their sentiment polarity: 1 is `positive`, 0 is `negative`\n", "\n", " The dataset has three splits: 67,349 training, 872 validation and 1,821 test sentences; see the [download page](https://gluebenchmark.com/tasks) for more\n", "\n", 
"In code, we call the `load_dataset` function from the `datasets` library with the `SST-2` task, which downloads and loads the data automatically\n", "\n", " After the first download, the data is cached under `~/.cache/huggingface/datasets/glue/`" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c5915debacf9443986b5b3b34870b303", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from datasets import load_dataset\n", "\n", "dataset = load_dataset('glue', task)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " After loading, print a sample using the field names of the `GLUE` `SST-2` dataset to check the result" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sentence: hide new secretions from the parental units \n" ] } ], "source": [ "task_to_keys = {\n", "    'cola': ('sentence', None),\n", "    'mnli': ('premise', 'hypothesis'),\n", "    'mrpc': ('sentence1', 'sentence2'),\n", "    'qnli': ('question', 'sentence'),\n", "    'qqp': ('question1', 'question2'),\n", "    'rte': ('sentence1', 'sentence2'),\n", "    'sst2': ('sentence', None),\n", "    'stsb': ('sentence1', 'sentence2'),\n", "    'wnli': ('sentence1', 'sentence2'),\n", "}\n", "\n", "sentence1_key, sentence2_key = task_to_keys[task]\n", "\n", "if sentence2_key is None:\n", "    print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n", "else:\n", "    print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n", "    print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. 
Preparation: loading the tokenizer, preprocessing the dataset, using the dataloader\n", "\n", " Next come the preparations for training: we use the `tokenizer` to tokenize and encode the dataset,\n", "\n", " and define `SeqClsDataset` together with a `dataloader` to load the data during training and evaluation\n", "\n", "Both the `tokenizer` and the sequence classification model here are based on the **`distilbert-base-uncased` model**,\n", "\n", " **a version of `bert-base` compressed by knowledge distillation**, with an uncased (lowercased) vocabulary; structurally it\n", "\n", " has **1 embedding layer** and **6 Transformer (self-attention) layers**, with **about `66M` parameters**; its module structure is listed at the end of this notebook, see the [DistilBert paper](https://arxiv.org/pdf/1910.01108.pdf) for more\n", "\n", "First, import the **`AutoTokenizer` class** from the `transformers` library and **initialize it with the `from_pretrained` function**\n", "\n", " `use_fast` selects whether to use the fast (Rust-backed) version of the `tokenizer`; we then encode a sample sentence pair to check the result\n", "\n", " Note the two keys returned: **`'input_ids'`** is the sequence of token IDs for the input text,\n", "\n", " and **`'attention_mask'`** is the mask used in self-attention (positions marked `0` correspond to `padding`)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" ] } ], "source": [ "model_checkpoint = 'distilbert-base-uncased'\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", "\n", "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, define a preprocessing function and apply it with **the `dataset.map` method**, which **adds the token-ID sequences as new fields** next to the original text" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ca1fbe5e8eb059f3.arrow\n", "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-03661263fbf302f5.arrow\n", "Loading cached processed dataset at 
/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fbe8e7a4e4f18f45.arrow\n" ] } ], "source": [ "def preprocess_function(examples):\n", "    if sentence2_key is None:\n", "        return tokenizer(examples[sentence1_key], truncation=True)\n", "    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n", "\n", "encoded_dataset = dataset.map(preprocess_function, batched=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then define the `SeqClsDataset` class by subclassing `torch`'s `Dataset`. Note that\n", "\n", " **the keys used in `__getitem__`** **must correspond to fields of the underlying dataset**\n", "\n", " For example, `'label'` is an original field of the `SST-2` dataset (which includes `'sentence'` and `'label'`),\n", "\n", " while `'input_ids'` and `'attention_mask'` are fields added by the `tokenizer` preprocessing; the label is wrapped in a list so that `collate_fn` can handle all three fields uniformly" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class SeqClsDataset(Dataset):\n", "    def __init__(self, dataset):\n", "        Dataset.__init__(self)\n", "        self.dataset = dataset\n", "\n", "    def __len__(self):\n", "        return len(self.dataset)\n", "\n", "    def __getitem__(self, item):\n", "        item = self.dataset[item]\n", "        return item['input_ids'], item['attention_mask'], [item['label']]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, **define the collate function `collate_fn`, which pads the samples within a `batch` to the same length with `0`**. Note that the function's\n", "\n", " **return value must be a dictionary** whose **keys correspond to the parameters of the model's `train_step` and `evaluate_step` functions**;\n", "\n", " this is the first **parameter matching** rule of `fastNLP v0.8`, already emphasized in `tutorial-0`" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def collate_fn(batch):\n", "    input_ids, atten_mask, labels = [], [], []\n", "    max_length = [0] * 3\n", "    for each_item in batch:\n", "        input_ids.append(each_item[0])\n", "        max_length[0] = max(max_length[0], len(each_item[0]))\n", "        atten_mask.append(each_item[1])\n", "        max_length[1] = max(max_length[1], len(each_item[1]))\n", "        labels.append(each_item[2])\n", "        max_length[2] = max(max_length[2], len(each_item[2]))\n", 
"\n", "    # pad every field to the longest length in the batch, using 0 as the padding value\n", "    for i in range(3):\n", "        each = (input_ids, atten_mask, labels)[i]\n", "        for item in each:\n", "            item.extend([0] * (max_length[i] - len(item)))\n", "    return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", "            'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", "            'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, wrap the tokenized training and validation sets in `SeqClsDataset` and split them into batches with `DataLoader`" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "dataset_train = SeqClsDataset(encoded_dataset['train'])\n", "dataloader_train = DataLoader(dataset=dataset_train,\n", "                              batch_size=32, shuffle=True, collate_fn=collate_fn)\n", "dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n", "dataloader_valid = DataLoader(dataset=dataset_valid,\n", "                              batch_size=32, shuffle=False, collate_fn=collate_fn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. 
Model training: loading distilbert-base, fastNLP parameter matching, fine-tuning\n", "\n", " The final step is model training: we build a classification model on `distilbert-base-uncased`,\n", "\n", " initialize the `optimizer` and the `trainer`, and complete training with the `run` function\n", "\n", "The model is built as an `nn.Module`; as with the `tokenizer`, we import the\n", "\n", " `AutoModelForSequenceClassification` class from the `transformers` library and initialize it from `distilbert-base-uncased`\n", "\n", "Pay attention to **the input parameters and output structure of `AutoModelForSequenceClassification`**\n", "\n", " On the one hand, **when the labels `labels` are passed in**, **the module computes the loss `loss` with its built-in loss function**,\n", "\n", " and the input can be given either as token IDs `input_ids` or as token embeddings `inputs_embeds`\n", "\n", " On the other hand, the module does not output predicted labels directly; instead it **outputs `logits`, the unnormalized score of each class**\n", "\n", " Based on this, the `train_step` and `evaluate_step` functions are defined here: `train_step` returns the `loss`, and `evaluate_step` takes the highest-scoring class as `pred`\n", "\n", " Note again that the return values of these functions reflect the second **parameter matching** rule of `fastNLP v0.8`" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "class SeqClsModel(nn.Module):\n", "    def __init__(self, num_labels, model_checkpoint):\n", "        nn.Module.__init__(self)\n", "        self.num_labels = num_labels\n", "        self.back_bone = AutoModelForSequenceClassification.from_pretrained(\n", "            model_checkpoint, num_labels=num_labels)\n", "\n", "    def forward(self, input_ids, attention_mask, labels=None):\n", "        output = self.back_bone(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n", "        return output\n", "\n", "    def train_step(self, input_ids, attention_mask, labels):\n", "        loss = self(input_ids, attention_mask, labels).loss\n", "        return {'loss': loss}\n", "\n", "    def evaluate_step(self, input_ids, attention_mask, labels):\n", "        pred = self(input_ids, attention_mask, labels).logits\n", "        pred = torch.max(pred, dim=-1)[1]  # index of the largest logit = predicted class\n", "        return {'pred': pred, 'target': labels}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, instantiate the model with the number of classes and initialize the optimizer with `torch.optim.AdamW`" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing 
DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "# MNLI has 3 classes, STS-B is regression (1 output), the other tasks are binary\n", "num_labels = 3 if task == 'mnli' else 1 if task == 'stsb' else 2\n", "\n", "model = SeqClsModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", "\n", "optimizers = AdamW(params=model.parameters(), lr=5e-5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then define the `trainer` with the `dataloader_train` and `dataloader_valid` built above" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "trainer = Trainer(\n", "    model=model,\n", "    driver='torch',\n", "    device=0,  # or 'cuda'\n", "    n_epochs=10,\n", "    optimizers=optimizers,\n", "    train_dataloader=dataloader_train,\n", "    evaluate_dataloaders=dataloader_valid,\n", "    metrics={'acc': Accuracy()}\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, train the model with `trainer.run`; the `n_epochs` argument above already sets training to `10` epochs\n", "\n", " and the `num_eval_batch_per_dl` argument limits each evaluation to `10` `batch`es of the validation set, i.e. 10 × 32 = 320 samples, which matches `total#acc` in the logs" ] }, { "cell_type": 
"code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[09:12:45] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", "\n" ], "text/plain": [ "\u001b[2;36m[09:12:45]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=408427;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=303634;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.884375,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 283.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.878125,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 281.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.884375,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 283.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.9,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 288.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.9\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m288.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.8875,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 284.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.88125,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 282.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.88125\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m282.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.875,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 280.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.865625,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 277.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.865625\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m277.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.884375,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 283.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"acc#acc\": 0.878125,\n", " \"total#acc\": 320.0,\n", " \"correct#acc\": 281.0\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'acc#acc': 0.884174, 'total#acc': 872.0, 'correct#acc': 771.0}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluator.run()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Appendix: structure of the `DistilBertForSequenceClassification` module\n", "\n", "```\n", "