fastNLP/tutorials/tutorial_9_callback.ipynb
2020-02-27 21:22:57 +08:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Customizing Your Training Process with Callbacks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- What is a Callback\n",
"- Using Callbacks\n",
"- Commonly used Callbacks\n",
"- Implementing a custom Callback"
]
},
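{
"cell_type": "markdown",
"metadata": {},
"source": [
"What Is a Callback\n",
"------\n",
"\n",
"A Callback is a module tightly coupled with the Trainer. With Callbacks you can add custom operations to the Trainer's training run, such as gradient clipping, learning-rate adjustment, or evaluating the model's performance. A Callback you define is invoked at specific stages of training.\n",
"\n",
"fastNLP provides many commonly used Callbacks that work out of the box."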
]
},
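{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the idea of hooks being called at specific stages concrete, the cell below is a minimal pure-Python sketch. It does not use fastNLP and is not how the fastNLP Trainer is implemented: ToyCallback, ToyTrainer, and PrintCallback are hypothetical names introduced only for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration only: these classes are hypothetical and are NOT fastNLP's implementation.\n",
"class ToyCallback:\n",
"    def on_train_begin(self):\n",
"        pass\n",
"\n",
"    def on_epoch_end(self, epoch):\n",
"        pass\n",
"\n",
"class ToyTrainer:\n",
"    def __init__(self, n_epochs, callbacks=None):\n",
"        self.n_epochs = n_epochs\n",
"        self.callbacks = callbacks or []\n",
"\n",
"    def train(self):\n",
"        # the trainer decides WHEN each hook runs; the callbacks decide WHAT happens\n",
"        for cb in self.callbacks:\n",
"            cb.on_train_begin()\n",
"        for epoch in range(1, self.n_epochs + 1):\n",
"            # ... forward pass, backward pass and optimizer step would go here ...\n",
"            for cb in self.callbacks:\n",
"                cb.on_epoch_end(epoch)\n",
"\n",
"class PrintCallback(ToyCallback):\n",
"    def on_train_begin(self):\n",
"        print('training started')\n",
"\n",
"    def on_epoch_end(self, epoch):\n",
"        print('finished epoch', epoch)\n",
"\n",
"ToyTrainer(n_epochs=2, callbacks=[PrintCallback()]).train()"
]
},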
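{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using Callbacks\n",
"------\n",
"\n",
"Using a Callback is simple: put the callbacks you need in a list and pass that list to the Trainer through its ``callbacks`` parameter. During training, the Trainer then automatically performs the operations these Callbacks specify."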
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-17T07:34:46.465871Z",
"start_time": "2019-09-17T07:34:30.648758Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In total 3 datasets:\n",
"\ttest has 1200 instances.\n",
"\ttrain has 9600 instances.\n",
"\tdev has 1200 instances.\n",
"In total 2 vocabs:\n",
"\tchars has 4409 entries.\n",
"\ttarget has 2 entries.\n",
"\n",
"training epochs started 2019-09-17-03-34-34\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
"AccuracyMetric: acc=0.863333\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.11 seconds!\n",
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
"AccuracyMetric: acc=0.886667\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
"AccuracyMetric: acc=0.890833\n",
"\n",
"\r\n",
"In Epoch:3/Step:900, got best dev performance:\n",
"AccuracyMetric: acc=0.890833\n",
"Reloaded the best model.\n"
]
}
],
"source": [
"from fastNLP import (Callback, EarlyStopCallback,\n",
"                     Trainer, CrossEntropyLoss, AccuracyMetric)\n",
"from fastNLP.models import CNNText\n",
"import torch.cuda\n",
"\n",
"# prepare data\n",
"def get_data():\n",
"    from fastNLP.io import ChnSentiCorpPipe as pipe\n",
"    data = pipe().process_from_file()\n",
"    print(data)\n",
"    data.rename_field('chars', 'words')\n",
"    train_data = data.datasets['train']\n",
"    dev_data = data.datasets['dev']\n",
"    test_data = data.datasets['test']\n",
"    vocab = data.vocabs['words']\n",
"    tgt_vocab = data.vocabs['target']\n",
"    return train_data, dev_data, test_data, vocab, tgt_vocab\n",
"\n",
"# prepare model\n",
"train_data, dev_data, _, vocab, tgt_vocab = get_data()\n",
"device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
"model = CNNText((len(vocab), 50), num_classes=len(tgt_vocab))\n",
"\n",
"# define callback\n",
"callbacks = [EarlyStopCallback(5)]\n",
"\n",
"# pass callbacks to Trainer\n",
"def train_with_callback(cb_list):\n",
"    trainer = Trainer(\n",
"        device=device,\n",
"        n_epochs=3,\n",
"        model=model,\n",
"        train_data=train_data,\n",
"        dev_data=dev_data,\n",
"        loss=CrossEntropyLoss(),\n",
"        metrics=AccuracyMetric(),\n",
"        callbacks=cb_list,\n",
"        check_code_level=-1\n",
"    )\n",
"    trainer.train()\n",
"\n",
"train_with_callback(callbacks)"
]
},
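{
"cell_type": "markdown",
"metadata": {},
"source": [
"Callbacks in fastNLP\n",
"-------\n",
"fastNLP provides many commonly used Callbacks, such as gradient clipping, early stopping during training, evaluation on a validation set, and fitlog logging. For the specific Callbacks, see fastNLP.core.callbacks."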
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-17T07:35:02.182727Z",
"start_time": "2019-09-17T07:34:49.443863Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2019-09-17-03-34-49\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.13 seconds!\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.12 seconds!\n",
"Evaluation on data-test:\n",
"AccuracyMetric: acc=0.890833\n",
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
"AccuracyMetric: acc=0.890833\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.09 seconds!\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.09 seconds!\n",
"Evaluation on data-test:\n",
"AccuracyMetric: acc=0.8875\n",
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
"AccuracyMetric: acc=0.8875\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.11 seconds!\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on data-test:\n",
"AccuracyMetric: acc=0.885\n",
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
"AccuracyMetric: acc=0.885\n",
"\n",
"\r\n",
"In Epoch:1/Step:300, got best dev performance:\n",
"AccuracyMetric: acc=0.890833\n",
"Reloaded the best model.\n"
]
}
],
"source": [
"from fastNLP import EarlyStopCallback, GradientClipCallback, EvaluateCallback\n",
"callbacks = [\n",
"    EarlyStopCallback(5),\n",
"    GradientClipCallback(clip_value=5, clip_type='value'),\n",
"    EvaluateCallback(dev_data)\n",
"]\n",
"\n",
"train_with_callback(callbacks)"
]
},
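{
"cell_type": "markdown",
"metadata": {},
"source": [
"Custom Callbacks\n",
"------\n",
"\n",
"Here we use a simple Callback as an example: it prints the average training loss of each epoch.\n",
"\n",
"#### Creating the Callback\n",
" \n",
"To define a custom Callback, implement a class that inherits from fastNLP.Callback.\n",
"\n",
"Here we define MyCallBack, which inherits from fastNLP.Callback.\n",
"\n",
"#### Specifying when the Callback is invoked\n",
" \n",
"All methods of a Callback whose names start with on_ are called by the Trainer at specific stages of training. For example, on_train_begin() is called when training starts, and on_epoch_end() is called at the end of each epoch. For the full list of these methods, see the Callback documentation.\n",
"\n",
"Here, MyCallBack uses on_backward_begin(), which is called once the loss has been computed, to record the current loss, and on_epoch_end(), which is called at the end of each epoch, to compute and print the average loss of that epoch.\n",
"\n",
"#### Accessing the Trainer's internal state through Callback attributes\n",
" \n",
"For convenience, a Callback exposes attributes that give access to the corresponding information in the Trainer, such as optimizer, epoch, and n_epochs, which are the optimizer used in training, the current epoch number, and the total number of epochs. For the full list of accessible attributes, see the Callback documentation.\n",
"\n",
"Here, to compute the average loss, MyCallBack needs the number of steps in the current epoch. The self.step attribute gives the number of steps trained so far, so the difference between its values at two consecutive epoch ends is the step count of that epoch.\n",
"\n"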
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-17T07:43:10.907139Z",
"start_time": "2019-09-17T07:42:58.488177Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2019-09-17-03-42-58\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.11 seconds!\n",
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
"AccuracyMetric: acc=0.883333\n",
"\n",
"Avg loss at epoch 1, 0.100254\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
"AccuracyMetric: acc=0.8775\n",
"\n",
"Avg loss at epoch 2, 0.183511\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.13 seconds!\n",
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
"AccuracyMetric: acc=0.875833\n",
"\n",
"Avg loss at epoch 3, 0.257103\n",
"\r\n",
"In Epoch:1/Step:300, got best dev performance:\n",
"AccuracyMetric: acc=0.883333\n",
"Reloaded the best model.\n"
]
}
],
"source": [
"from fastNLP import Callback\n",
"from fastNLP import logger\n",
"\n",
"class MyCallBack(Callback):\n",
"    \"\"\"Print average loss in each epoch\"\"\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.total_loss = 0\n",
"        self.start_step = 0\n",
"    \n",
"    def on_backward_begin(self, loss):\n",
"        # accumulate the loss of every training step\n",
"        self.total_loss += loss.item()\n",
"    \n",
"    def on_epoch_end(self):\n",
"        n_steps = self.step - self.start_step\n",
"        avg_loss = self.total_loss / n_steps\n",
"        logger.info('Avg loss at epoch %d, %.6f', self.epoch, avg_loss)\n",
"        # reset the accumulators so the next average covers only its own epoch\n",
"        self.total_loss = 0\n",
"        self.start_step = self.step\n",
"\n",
"callbacks = [MyCallBack()]\n",
"train_with_callback(callbacks)"
]
},
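{
"cell_type": "markdown",
"metadata": {},
"source": [
"The bookkeeping behind a per-epoch average can be checked without a Trainer. The cell below is a pure-Python sketch with made-up loss values: EpochAverager and the numbers are hypothetical and are not part of fastNLP. It clears the running sum and moves the step marker at every epoch boundary, so each average covers only its own epoch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration only: EpochAverager is a hypothetical helper, not a fastNLP class.\n",
"class EpochAverager:\n",
"    def __init__(self):\n",
"        self.total_loss = 0.0\n",
"        self.step = 0        # total number of steps seen so far\n",
"        self.start_step = 0  # value of step when the current epoch began\n",
"\n",
"    def record(self, loss):\n",
"        self.total_loss += loss\n",
"        self.step += 1\n",
"\n",
"    def epoch_end(self):\n",
"        n_steps = self.step - self.start_step\n",
"        avg = self.total_loss / n_steps\n",
"        # reset so the next epoch is averaged independently\n",
"        self.total_loss = 0.0\n",
"        self.start_step = self.step\n",
"        return avg\n",
"\n",
"averager = EpochAverager()\n",
"made_up_losses = [[0.5, 0.25], [0.125, 0.125]]  # one inner list per epoch\n",
"for epoch, losses in enumerate(made_up_losses, start=1):\n",
"    for loss in losses:\n",
"        averager.record(loss)\n",
"    print('epoch', epoch, 'avg loss', averager.epoch_end())"
]
},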
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}