fastNLP/tutorials/tutorial_7_metrics.ipynb
2020-02-28 00:44:15 +08:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quickly evaluating your model with Metrics\n",
"\n",
"The experiment setup code is the same as in the previous tutorial."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.io import SST2Pipe\n",
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
"from fastNLP.models import CNNText\n",
"import torch\n",
"\n",
"databundle = SST2Pipe().process_from_file()\n",
"vocab = databundle.get_vocab('words')\n",
"train_data = databundle.get_dataset('train')[:5000]\n",
"train_data, test_data = train_data.split(0.015)\n",
"dev_data = databundle.get_dataset('dev')\n",
"\n",
"model = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n",
"loss = CrossEntropyLoss()\n",
"metric = AccuracyMetric()\n",
"device = 0 if torch.cuda.is_available() else 'cpu'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
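"fastNLP provides a variety of metrics for evaluation during training. As introduced in the previous tutorials, an AccuracyMetric object can be passed directly to the Trainer, which uses it to evaluate the model on dev_data during training."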
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2020-02-28-00-37-08\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.28 seconds!\n",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"AccuracyMetric: acc=0.747706\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.17 seconds!\n",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"AccuracyMetric: acc=0.745413\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.19 seconds!\n",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"AccuracyMetric: acc=0.74656\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.15 seconds!\n",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"AccuracyMetric: acc=0.762615\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"AccuracyMetric: acc=0.736239\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.16 seconds!\n",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"AccuracyMetric: acc=0.761468\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"AccuracyMetric: acc=0.727064\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.21 seconds!\n",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"AccuracyMetric: acc=0.731651\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.52 seconds!\n",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"AccuracyMetric: acc=0.752294\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.44 seconds!\n",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"AccuracyMetric: acc=0.760321\n",
"\n",
"\r\n",
"In Epoch:4/Step:616, got best dev performance:\n",
"AccuracyMetric: acc=0.762615\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccuracyMetric': {'acc': 0.762615}},\n",
" 'best_epoch': 4,\n",
" 'best_step': 616,\n",
" 'seconds': 32.63}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=metric)\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides AccuracyMetric, SpanFPreRecMetric is also a very common metric: in sequence labeling, for example, F-measure, precision and recall are usually computed at the span level.\n",
"\n",
"fastNLP also implements ExtractiveQAMetric, a metric for extractive QA tasks such as SQuAD. The table below gives an overview.\n",
"\n",
"| Name | Description |\n",
"| -------------------- | ------------------------------------------------- |\n",
"| `MetricBase` | Base class that custom metrics must inherit from |\n",
"| `AccuracyMetric` | Simple accuracy metric |\n",
"| `SpanFPreRecMetric` | Computes F-measure, precision and recall together |\n",
"| `ExtractiveQAMetric` | Metric for extractive QA tasks |\n",
"\n"
]
},
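{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make \"span level\" concrete, here is a minimal plain-Python sketch of span-based precision, recall and F-measure. It is not fastNLP's SpanFPreRecMetric itself: it assumes the spans are already given as hypothetical `(label, start, end)` tuples, and counts a predicted span as correct only when its label and both boundaries match a gold span exactly.\n",
"\n",
"```python\n",
"def span_prf(pred_spans, gold_spans):\n",
"    # each span is a (label, start, end) tuple; only exact matches count\n",
"    pred_set, gold_set = set(pred_spans), set(gold_spans)\n",
"    tp = len(pred_set & gold_set)  # true positives: spans in both sets\n",
"    precision = tp / len(pred_set) if pred_set else 0.0\n",
"    recall = tp / len(gold_set) if gold_set else 0.0\n",
"    if precision + recall == 0:\n",
"        return {'f1': 0.0, 'precision': precision, 'recall': recall}\n",
"    f1 = 2 * precision * recall / (precision + recall)\n",
"    return {'f1': f1, 'precision': precision, 'recall': recall}\n",
"\n",
"gold = [('PER', 0, 2), ('LOC', 5, 6)]  # made-up gold spans\n",
"pred = [('PER', 0, 2), ('LOC', 4, 6)]  # made-up predicted spans\n",
"print(span_prf(pred, gold))\n",
"```\n",
"\n",
"Only one of the two predicted spans matches exactly, so precision, recall and F-measure are all 0.5; the LOC span with the wrong start boundary earns no partial credit."
]
},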
{
"cell_type": "markdown",
"metadata": {},
"source": [
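"## Defining your own metrics\n",
"\n",
"To define your own metric class, inherit from fastNLP's MetricBase and override the evaluate and get_metric methods.\n",
"\n",
"- evaluate(xxx) receives one batch of data and accumulates the metric statistics for that batch's predictions\n",
"\n",
"- get_metric(xxx) is called once all the data has been processed; it computes the final result from the statistics accumulated by evaluate\n",
"\n",
"Take accuracy for a classification problem as an example. Suppose the dict returned by the model's forward contains the key pred, which is needed to compute accuracy:\n",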
"\n",
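"```python\n",
"from torch import nn\n",
"\n",
"class Model(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # define the layers here\n",
"\n",
"    def forward(self, words, seq_len):\n",
"        # compute pred from the inputs here\n",
"        pred = ...\n",
"        return {'pred': pred, 'other_keys': ...}  # pred's shape: batch_size x num_classes\n",
"```"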
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Version 1\n",
"\n",
"Assume the `target` field of the dataset holds the values to predict and has been set as a target. The corresponding `AccMetric` can then be defined as follows:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
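"from fastNLP import MetricBase\n",
"\n",
"class AccMetric(MetricBase):\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # define the statistics you need\n",
"        self.total = 0\n",
"        self.acc_count = 0\n",
"\n",
"    # the parameter names of evaluate must match the field names in the DataSet and the keys of the model output,\n",
"    # otherwise the corresponding values cannot be found; pred and target are fastNLP's default names\n",
"    def evaluate(self, pred, target):\n",
"        # called once at the end of every batch during dev/test; accumulate this batch's statistics\n",
"        if pred.dim() == target.dim() + 1:  # pred holds per-class scores (batch_size x num_classes)\n",
"            pred = pred.argmax(dim=-1)\n",
"        self.total += target.size(0)\n",
"        self.acc_count += target.eq(pred).sum().item()\n",
"\n",
"    def get_metric(self, reset=True):  # compute the final metric here\n",
"        acc = self.acc_count / self.total\n",
"        if reset:  # clear the statistics so that the next evaluation starts from zero\n",
"            self.acc_count = 0\n",
"            self.total = 0\n",
"        # return a dict whose key is the metric name; the name is shown in the Trainer's progress bar\n",
"        return {'acc': acc}"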
]
},
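{
"cell_type": "markdown",
"metadata": {},
"source": [
"The division of labor between evaluate and get_metric can be checked without fastNLP or a model. The plain-Python sketch below mirrors AccMetric on lists instead of tensors; `ListAccMetric` is our own illustrative name, the batches are made-up values, and the loop is a simplified stand-in for what happens during evaluation, not fastNLP's actual code.\n",
"\n",
"```python\n",
"class ListAccMetric:\n",
"    # same evaluate/get_metric protocol as AccMetric, on plain lists\n",
"    def __init__(self):\n",
"        self.total = 0\n",
"        self.acc_count = 0\n",
"\n",
"    def evaluate(self, pred, target):\n",
"        # called once per batch: only accumulate counts\n",
"        self.total += len(target)\n",
"        self.acc_count += sum(p == t for p, t in zip(pred, target))\n",
"\n",
"    def get_metric(self, reset=True):\n",
"        # called once after all batches: turn the counts into the final score\n",
"        acc = self.acc_count / self.total\n",
"        if reset:\n",
"            self.acc_count = 0\n",
"            self.total = 0\n",
"        return {'acc': acc}\n",
"\n",
"# hypothetical dev batches of (pred, target); the last batch is smaller\n",
"batches = [([1, 0, 1], [1, 1, 1]), ([0, 0, 1], [0, 0, 1]), ([1], [0])]\n",
"metric = ListAccMetric()\n",
"for pred, target in batches:\n",
"    metric.evaluate(pred, target)\n",
"print(metric.get_metric())\n",
"```\n",
"\n",
"The result is 5/7, the accuracy over all seven samples. Averaging the three per-batch accuracies (2/3, 1 and 0) would give about 0.556 instead, because the last batch is smaller; this is why evaluate only accumulates counts and get_metric divides once at the end."
]
},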
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2020-02-28-00-37-41\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.27 seconds!\n",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"AccMetric: acc=0.7431192660550459\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"AccMetric: acc=0.7522935779816514\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.51 seconds!\n",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"AccMetric: acc=0.7477064220183486\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.48 seconds!\n",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"AccMetric: acc=0.7442660550458715\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.5 seconds!\n",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"AccMetric: acc=0.7362385321100917\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.45 seconds!\n",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"AccMetric: acc=0.7293577981651376\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.33 seconds!\n",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"AccMetric: acc=0.7190366972477065\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.29 seconds!\n",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"AccMetric: acc=0.7419724770642202\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.34 seconds!\n",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"AccMetric: acc=0.7350917431192661\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.18 seconds!\n",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"AccMetric: acc=0.6846330275229358\n",
"\n",
"\r\n",
"In Epoch:2/Step:308, got best dev performance:\n",
"AccMetric: acc=0.7522935779816514\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccMetric': {'acc': 0.7522935779816514}},\n",
" 'best_epoch': 2,\n",
" 'best_step': 308,\n",
" 'seconds': 42.7}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=AccMetric())\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Version 2\n",
"\n",
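"If you need to reuse the metric, for example because the next time you use `AccMetric` the target field in the dataset is called `y` instead of `target`, or the model's output key is not `pred`, register a parameter mapping in `__init__` as follows:\n"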
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
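"class AccMetric(MetricBase):\n",
"    def __init__(self, pred=None, target=None):\n",
"        \"\"\"\n",
"        Suppose that in another scenario the target field is called y and the model outputs the key pred_y.\n",
"        Then it is enough to initialize the metric as acc_metric = AccMetric(pred='pred_y', target='y').\n",
"        When initialized as acc_metric = AccMetric(), fastNLP uses 'pred' and 'target' directly as the keys to look up the values.\n",
"        \"\"\"\n",
"\n",
"        super().__init__()\n",
"\n",
"        # without this registration the behavior is the same as in Version 1\n",
"        self._init_param_map(pred=pred, target=target)  # registers the names for pred and target; only the parameters used by evaluate() need to be registered\n",
"\n",
"        # define the statistics you need\n",
"        self.total = 0\n",
"        self.acc_count = 0\n",
"\n",
"    # the parameter names of evaluate must match the field names in the DataSet and the keys of the model output,\n",
"    # otherwise the corresponding values cannot be found; pred and target are fastNLP's default names\n",
"    def evaluate(self, pred, target):\n",
"        # called once at the end of every batch during dev/test; accumulate this batch's statistics\n",
"        if pred.dim() == target.dim() + 1:  # pred holds per-class scores (batch_size x num_classes)\n",
"            pred = pred.argmax(dim=-1)\n",
"        self.total += target.size(0)\n",
"        self.acc_count += target.eq(pred).sum().item()\n",
"\n",
"    def get_metric(self, reset=True):  # compute the final metric here\n",
"        acc = self.acc_count / self.total\n",
"        if reset:  # clear the statistics so that the next evaluation starts from zero\n",
"            self.acc_count = 0\n",
"            self.total = 0\n",
"        # return a dict whose key is the metric name; the name is shown in the Trainer's progress bar\n",
"        return {'acc': acc}"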
]
},
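{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter mapping registered by `_init_param_map` can be pictured as a dict from the parameter names of evaluate to the keys that are actually looked up. The sketch below is a hypothetical plain-Python illustration of that idea only; `build_param_map` and `fetch_evaluate_args` are our own names, not fastNLP functions, and the model output `pred_y` and target field `y` are assumed for the example.\n",
"\n",
"```python\n",
"def build_param_map(pred=None, target=None):\n",
"    # a None argument falls back to the parameter's own name, as with AccMetric()\n",
"    return {'pred': pred or 'pred', 'target': target or 'target'}\n",
"\n",
"def fetch_evaluate_args(param_map, pred_dict, target_dict):\n",
"    # look up each evaluate() parameter under its mapped key\n",
"    merged = {**target_dict, **pred_dict}\n",
"    return {param: merged[key] for param, key in param_map.items()}\n",
"\n",
"pred_dict = {'pred_y': [1, 0, 1]}  # hypothetical model output\n",
"target_dict = {'y': [1, 1, 1]}     # hypothetical batch of the target field y\n",
"param_map = build_param_map(pred='pred_y', target='y')\n",
"print(fetch_evaluate_args(param_map, pred_dict, target_dict))\n",
"# {'pred': [1, 0, 1], 'target': [1, 1, 1]}\n",
"```\n",
"\n",
"With the default mapping from `build_param_map()`, the lookup would use the keys `pred` and `target`, which corresponds to the behavior of Version 1."
]
},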
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2020-02-28-00-38-24\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.32 seconds!\n",
"\r",
"Evaluation on dev at Epoch 1/10. Step:154/1540: \n",
"\r",
"AccMetric: acc=0.7511467889908257\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.29 seconds!\n",
"\r",
"Evaluation on dev at Epoch 2/10. Step:308/1540: \n",
"\r",
"AccMetric: acc=0.7454128440366973\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 3/10. Step:462/1540: \n",
"\r",
"AccMetric: acc=0.7224770642201835\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.4 seconds!\n",
"\r",
"Evaluation on dev at Epoch 4/10. Step:616/1540: \n",
"\r",
"AccMetric: acc=0.7534403669724771\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.41 seconds!\n",
"\r",
"Evaluation on dev at Epoch 5/10. Step:770/1540: \n",
"\r",
"AccMetric: acc=0.7396788990825688\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.22 seconds!\n",
"\r",
"Evaluation on dev at Epoch 6/10. Step:924/1540: \n",
"\r",
"AccMetric: acc=0.7442660550458715\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.45 seconds!\n",
"\r",
"Evaluation on dev at Epoch 7/10. Step:1078/1540: \n",
"\r",
"AccMetric: acc=0.6903669724770642\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.25 seconds!\n",
"\r",
"Evaluation on dev at Epoch 8/10. Step:1232/1540: \n",
"\r",
"AccMetric: acc=0.7293577981651376\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.4 seconds!\n",
"\r",
"Evaluation on dev at Epoch 9/10. Step:1386/1540: \n",
"\r",
"AccMetric: acc=0.7006880733944955\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.48 seconds!\n",
"\r",
"Evaluation on dev at Epoch 10/10. Step:1540/1540: \n",
"\r",
"AccMetric: acc=0.7339449541284404\n",
"\n",
"\r\n",
"In Epoch:4/Step:616, got best dev performance:\n",
"AccMetric: acc=0.7534403669724771\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccMetric': {'acc': 0.7534403669724771}},\n",
" 'best_epoch': 4,\n",
" 'best_step': 616,\n",
" 'seconds': 34.74}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = Trainer(train_data=train_data, dev_data=dev_data, model=model,\n",
" loss=loss, device=device, metrics=AccMetric())\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
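"``MetricBase`` checks the input dicts ``pred_dict`` and ``target_dict``.\n",
"``pred_dict`` is the return value of the model's ``forward()`` or ``predict()`` function.\n",
"``target_dict`` holds the ground truth from the DataSet; a field counts as ground truth when its ``is_target`` is set to True.\n",
"\n",
"``MetricBase`` performs the following checks:\n",
"\n",
"1. whether self.evaluate takes varargs (``*args``), which are not supported;\n",
"2. whether a parameter required by self.evaluate is in neither ``pred_dict`` nor ``target_dict``;\n",
"3. whether a parameter required by self.evaluate is in both ``pred_dict`` and ``target_dict``.\n",
"\n",
"In addition, before the arguments are passed to self.evaluate, it detects the entries of ``pred_dict`` and ``target_dict`` that are not used.\n",
"If self.evaluate accepts ``**kwargs``, this detection is skipped.\n",
"\n",
"self.evaluate computes the metric statistics for one batch and accumulates them; it returns nothing.\n",
"self.get_metric computes the final result from the accumulated statistics and returns it as a dict whose keys are metric names and whose values are the metric values.\n"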
]
},
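{
"cell_type": "markdown",
"metadata": {},
"source": [
"The checks listed above can be sketched with Python's standard `inspect` module. This is a hypothetical illustration of the logic only: `check_evaluate_call` is our own name, and fastNLP's internal implementation and error messages may differ.\n",
"\n",
"```python\n",
"import inspect\n",
"\n",
"P = inspect.Parameter\n",
"\n",
"def check_evaluate_call(evaluate, pred_dict, target_dict):\n",
"    params = inspect.signature(evaluate).parameters.values()\n",
"    # check 1: varargs (*args) are not supported\n",
"    if any(p.kind is P.VAR_POSITIONAL for p in params):\n",
"        raise TypeError('evaluate() must not take *args')\n",
"    accepts_kwargs = any(p.kind is P.VAR_KEYWORD for p in params)\n",
"    named = [p.name for p in params if p.kind is not P.VAR_KEYWORD]\n",
"    required = [p.name for p in params\n",
"                if p.kind is not P.VAR_KEYWORD and p.default is P.empty]\n",
"    # check 2: a required parameter found in neither dict\n",
"    missing = [n for n in required if n not in pred_dict and n not in target_dict]\n",
"    # check 3: a parameter found in both dicts is ambiguous\n",
"    duplicated = [n for n in named if n in pred_dict and n in target_dict]\n",
"    if missing or duplicated:\n",
"        raise ValueError(f'missing: {missing}, duplicated: {duplicated}')\n",
"    # keys that evaluate() never uses are reported, unless it accepts **kwargs\n",
"    if accepts_kwargs:\n",
"        return []\n",
"    return sorted((set(pred_dict) | set(target_dict)) - set(named))\n",
"\n",
"def evaluate(pred, target):  # stands in for a metric's bound evaluate method\n",
"    pass\n",
"\n",
"print(check_evaluate_call(evaluate, {'pred': [1], 'seq_len': [3]}, {'target': [1]}))\n",
"```\n",
"\n",
"Here `seq_len` is reported as unused because evaluate() does not take it; adding `**kwargs` to evaluate() would skip the report, matching the rule stated above."
]
},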
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python Now",
"language": "python",
"name": "now"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}