update example-12 lxr 220527

This commit is contained in:
lxr-tech 2022-05-27 22:47:11 +08:00
parent 49e8ae2daa
commit 19a48c7101
2 changed files with 241 additions and 631 deletions

View File

@ -233,7 +233,7 @@
}
],
"source": [
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n",
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n",
"\n",
"model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
"\n",
@ -881,6 +881,15 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,

View File

@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# E2. 使用 PrefixTuning 完成 SST2 分类"
"# E2. 使用 continuous prompt 完成 SST2 分类"
]
},
{
@ -35,10 +35,12 @@
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.optim import AdamW\n",
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"import torch.nn as nn\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"import transformers\n",
"from transformers import AutoTokenizer\n",
"from transformers import AutoModelForSequenceClassification\n",
@ -69,6 +71,136 @@
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class PromptEncoder(nn.Module):\n",
" def __init__(self, template, hidden_size):\n",
" nn.Module.__init__(self)\n",
" self.template = template\n",
" self.hidden_size = hidden_size\n",
" self.cloze_mask = [[1] * self.template[0] + [1] * self.template[1]]\n",
" self.cloze_mask = torch.LongTensor(self.cloze_mask).bool()\n",
"\n",
" self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0]))))\n",
" # embed\n",
" self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), hidden_size)\n",
" # LSTM\n",
" self.lstm_head = torch.nn.LSTM(input_size=hidden_size,\n",
" hidden_size=hidden_size // 2,\n",
" num_layers=2, dropout=0.0,\n",
" bidirectional=True, batch_first=True)\n",
" # MLP\n",
" self.mlp_head = nn.Sequential(nn.Linear(hidden_size, hidden_size),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_size, hidden_size))\n",
" print(\"init prompt encoder...\")\n",
"\n",
" def forward(self, device):\n",
" input_embeds = self.embedding(self.seq_indices.to(device)).unsqueeze(0)\n",
" output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()\n",
" return output_embeds"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class ClassModel(nn.Module):\n",
" def __init__(self, num_labels, model_checkpoint, pseudo_token='[PROMPT]', template=(3, 3)):\n",
" nn.Module.__init__(self)\n",
" self.template = template\n",
" self.num_labels = num_labels\n",
" self.spell_length = sum(template)\n",
" self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
" self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n",
" num_labels=num_labels)\n",
" for param in self.back_bone.parameters():\n",
" param.requires_grad = False\n",
" self.embeddings = self.back_bone.get_input_embeddings()\n",
" \n",
" self.hidden_size = self.embeddings.embedding_dim\n",
" self.tokenizer.add_special_tokens({'additional_special_tokens': [pseudo_token]})\n",
" self.pseudo_token_id = self.tokenizer.get_vocab()[pseudo_token]\n",
" self.pad_token_id = self.tokenizer.pad_token_id\n",
" \n",
" self.prompt_encoder = PromptEncoder(self.template, self.hidden_size)\n",
"\n",
" self.loss_fn = nn.CrossEntropyLoss()\n",
"\n",
" def get_query(self, query):\n",
" device = query.device\n",
" return torch.cat([torch.tensor([self.tokenizer.cls_token_id]).to(device), # [CLS]\n",
" torch.tensor([self.pseudo_token_id] * self.template[0]).to(device), # [PROMPT]\n",
" torch.tensor([self.tokenizer.mask_token_id]).to(device), # [MASK] \n",
" torch.tensor([self.pseudo_token_id] * self.template[1]).to(device), # [PROMPT]\n",
" query, \n",
" torch.tensor([self.tokenizer.sep_token_id]).to(device)], dim=0) # [SEP]\n",
"\n",
" def forward(self, input_ids):\n",
" input_ids = torch.stack([self.get_query(input_ids[i]) for i in range(len(input_ids))])\n",
" attention_mask = input_ids != self.pad_token_id\n",
" \n",
" bz = input_ids.shape[0]\n",
" inputs_embeds = input_ids.clone()\n",
" inputs_embeds[(input_ids == self.pseudo_token_id)] = self.tokenizer.unk_token_id\n",
" inputs_embeds = self.embeddings(inputs_embeds)\n",
"\n",
" blocked_indices = (input_ids == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1] # bz\n",
" replace_embeds = self.prompt_encoder(input_ids.device)\n",
" for bidx in range(bz):\n",
" for i in range(self.spell_length):\n",
" inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[i, :]\n",
" \n",
" return self.back_bone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)\n",
"\n",
" def train_step(self, input_ids, attention_mask, labels):\n",
" pred = self(input_ids).logits\n",
" return {\"loss\": self.loss_fn(pred, labels)}\n",
"\n",
" def evaluate_step(self, input_ids, attention_mask, labels):\n",
" pred = self(input_ids).logits\n",
" pred = torch.max(pred, dim=-1)[1]\n",
" return {\"pred\": pred, \"target\": labels}"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight']\n",
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"init prompt encoder...\n"
]
}
],
"source": [
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n",
"\n",
"model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-4)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": false
},
@ -77,14 +209,13 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n",
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "253d79d7a67e4dc88338448b5bcb3fb9",
"model_id": "f82d2ccee863492582f94552654482f9",
"version_major": 2,
"version_minor": 0
},
@ -102,88 +233,57 @@
"dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
]
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
"\n",
"print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"task_to_keys = {\n",
" \"cola\": (\"sentence\", None),\n",
" \"mnli\": (\"premise\", \"hypothesis\"),\n",
" \"mnli-mm\": (\"premise\", \"hypothesis\"),\n",
" \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
" \"qnli\": (\"question\", \"sentence\"),\n",
" \"qqp\": (\"question1\", \"question2\"),\n",
" \"rte\": (\"sentence1\", \"sentence2\"),\n",
" \"sst2\": (\"sentence\", None),\n",
" \"stsb\": (\"sentence1\", \"sentence2\"),\n",
" \"wnli\": (\"sentence1\", \"sentence2\"),\n",
"}\n",
"\n",
"sentence1_key, sentence2_key = task_to_keys[task]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sentence: hide new secretions from the parental units \n"
]
}
],
"source": [
"if sentence2_key is None:\n",
" print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n",
"else:\n",
" print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n",
" print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ca1fbe5e8eb059f3.arrow\n",
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-03661263fbf302f5.arrow\n",
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fbe8e7a4e4f18f45.arrow\n"
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf324902e7b94ea9be709b979b425c96",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/68 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "21eb6203ec6f4592b8cb8530a59eda49",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "05b83c4b1a9f44aea805788e1e52db78",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2 [00:00<?, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def preprocess_function(examples):\n",
" if sentence2_key is None:\n",
" return tokenizer(examples[sentence1_key], truncation=True)\n",
" return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n",
" return model.tokenizer(examples['sentence'], truncation=True)\n",
"\n",
"encoded_dataset = dataset.map(preprocess_function, batched=True)"
]
@ -193,58 +293,6 @@
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class ClassModel(nn.Module):\n",
" def __init__(self, num_labels, model_checkpoint):\n",
" nn.Module.__init__(self)\n",
" self.num_labels = num_labels\n",
" self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n",
" num_labels=num_labels)\n",
" self.loss_fn = nn.CrossEntropyLoss()\n",
"\n",
" def forward(self, input_ids, attention_mask):\n",
" return self.back_bone(input_ids, attention_mask)\n",
"\n",
" def train_step(self, input_ids, attention_mask, labels):\n",
" pred = self(input_ids, attention_mask).logits\n",
" return {\"loss\": self.loss_fn(pred, labels)}\n",
"\n",
" def evaluate_step(self, input_ids, attention_mask, labels):\n",
" pred = self(input_ids, attention_mask).logits\n",
" pred = torch.max(pred, dim=-1)[1]\n",
" return {\"pred\": pred, \"target\": labels}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']\n",
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n",
"\n",
"model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-5)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class TestDistilBertDataset(Dataset):\n",
" def __init__(self, dataset):\n",
@ -261,7 +309,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@ -287,7 +335,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -301,7 +349,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@ -319,511 +367,15 @@
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# help(model.back_bone.forward)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[21:00:11] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#592\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">592</span></a>\n",
"</pre>\n"
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": [
"\u001b[2;36m[21:00:11]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=22992;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=669026;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.871875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">279.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.878125</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">281.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.871875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">279.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.903125</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">289.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.903125\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m289.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.871875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">279.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.890625</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">285.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">280.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">284.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8875</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">284.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
"</pre>\n"
],
"text/plain": [
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.890625</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">285.0</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n",
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
@ -856,6 +408,46 @@
"trainer.run(num_eval_batch_per_dl=10)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'acc#acc': 0.565367, 'total#acc': 872.0, 'correct#acc': 493.0}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluator.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -881,6 +473,15 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,