fastNLP/tutorials/tutorial_6_datasetiter.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 使用Trainer和Tester快速训练和测试"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据读入和处理"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"In total 3 datasets:\n",
"\ttest has 1821 instances.\n",
"\ttrain has 67349 instances.\n",
"\tdev has 872 instances.\n",
"In total 2 vocabs:\n",
"\twords has 16292 entries.\n",
"\ttarget has 2 entries.\n",
"\n",
"+-----------------------------------+--------+-----------------------------------+---------+\n",
"| raw_words | target | words | seq_len |\n",
"+-----------------------------------+--------+-----------------------------------+---------+\n",
"| hide new secretions from the p... | 1 | [4110, 97, 12009, 39, 2, 6843,... | 7 |\n",
"+-----------------------------------+--------+-----------------------------------+---------+\n",
"Vocabulary(['hide', 'new', 'secretions', 'from', 'the']...)\n"
]
}
],
"source": [
"from fastNLP.io import SST2Pipe\n",
"\n",
"pipe = SST2Pipe()\n",
"databundle = pipe.process_from_file()\n",
"vocab = databundle.get_vocab('words')\n",
"print(databundle)\n",
"print(databundle.get_dataset('train')[0])\n",
"print(databundle.get_vocab('words'))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4925 872 75\n"
]
}
],
"source": [
"train_data = databundle.get_dataset('train')[:5000]\n",
"train_data, test_data = train_data.split(0.015)\n",
"dev_data = databundle.get_dataset('dev')\n",
"print(len(train_data),len(dev_data),len(test_data))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"+-------------+-----------+--------+-------+---------+\n",
"| field_names | raw_words | target | words | seq_len |\n",
"+-------------+-----------+--------+-------+---------+\n",
"| is_input | False | False | True | True |\n",
"| is_target | False | True | False | False |\n",
"| ignore_type | | False | False | False |\n",
"| pad_value | | 0 | 0 | 0 |\n",
"+-------------+-----------+--------+-------+---------+\n"
]
},
{
"data": {
"text/plain": [
"<prettytable.PrettyTable at 0x7f0db03d0640>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data.print_field_meta()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import AccuracyMetric\n",
"from fastNLP import Const\n",
"\n",
"# metrics=AccuracyMetric() 在本例中与下面这行代码等价\n",
"metrics=AccuracyMetric(pred=Const.OUTPUT, target=Const.TARGET)"
]
},
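{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Const.OUTPUT` and `Const.TARGET` are just canonical field-name strings. As a quick sanity check (printing the values as defined by fastNLP's `Const`), the cell below shows which keys `AccuracyMetric` will look up:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Const.OUTPUT names the prediction field and Const.TARGET the label field;\n",
"# these are the dict keys AccuracyMetric reads from the model output and batch_y.\n",
"print(Const.OUTPUT, Const.TARGET)"
]
},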
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DataSetIter初探"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n",
" 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n",
" 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n",
" 1323, 4398, 7],\n",
" [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n",
" 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n",
" 7, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0]]), 'seq_len': tensor([33, 21])}\n",
"batch_y: {'target': tensor([1, 0])}\n",
"batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7],\n",
" [ 14, 10, 437, 32, 78, 3, 78, 437, 7]]), 'seq_len': tensor([9, 9])}\n",
"batch_y: {'target': tensor([0, 1])}\n",
"batch_x: {'words': tensor([[ 4, 277, 685, 18, 7],\n",
" [15618, 3204, 5, 1675, 0]]), 'seq_len': tensor([5, 4])}\n",
"batch_y: {'target': tensor([1, 1])}\n",
"batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n",
" 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7],\n",
" [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n",
" 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7]]), 'seq_len': tensor([20, 20])}\n",
"batch_y: {'target': tensor([0, 0])}\n",
"batch_x: {'words': tensor([[ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n",
" 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7],\n",
" [ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n",
" 1217, 7, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([20, 12])}\n",
"batch_y: {'target': tensor([0, 1])}\n"
]
}
],
"source": [
"from fastNLP import BucketSampler\n",
"from fastNLP import DataSetIter\n",
"\n",
"tmp_data = dev_data[:10]\n",
"# 定义一个Batch传入DataSet规定batch_size和去batch的规则。\n",
"# 顺序Sequential随机Random相似长度组成一个batchBucket\n",
"sampler = BucketSampler(batch_size=2, seq_len_field_name='seq_len')\n",
"batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n",
"for batch_x, batch_y in batch:\n",
" print(\"batch_x: \",batch_x)\n",
" print(\"batch_y: \", batch_y)"
]
},
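{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides `BucketSampler`, fastNLP also ships `SequentialSampler` and `RandomSampler`. The cell below is a minimal sketch contrasting the two on the same `tmp_data`; it assumes both are importable from the top-level `fastNLP` package and take no constructor arguments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import SequentialSampler, RandomSampler\n",
"\n",
"# SequentialSampler yields instances in dataset order; RandomSampler shuffles them.\n",
"# Printing the seq_len of each batch makes the ordering difference visible.\n",
"for name, smp in [('sequential', SequentialSampler()), ('random', RandomSampler())]:\n",
"    it = DataSetIter(batch_size=2, dataset=tmp_data, sampler=smp)\n",
"    print(name, [batch_x['seq_len'].tolist() for batch_x, batch_y in it])"
]
},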
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n",
" 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n",
" 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n",
" 1323, 4398, 7],\n",
" [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n",
" 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n",
" 7, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
" -1, -1, -1]]), 'seq_len': tensor([33, 21])}\n",
"batch_y: {'target': tensor([1, 0])}\n",
"batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7],\n",
" [ 14, 10, 437, 32, 78, 3, 78, 437, 7]]), 'seq_len': tensor([9, 9])}\n",
"batch_y: {'target': tensor([0, 1])}\n",
"batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n",
" 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7],\n",
" [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n",
" 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7]]), 'seq_len': tensor([20, 20])}\n",
"batch_y: {'target': tensor([0, 0])}\n",
"batch_x: {'words': tensor([[ 4, 277, 685, 18, 7],\n",
" [15618, 3204, 5, 1675, -1]]), 'seq_len': tensor([5, 4])}\n",
"batch_y: {'target': tensor([1, 1])}\n",
"batch_x: {'words': tensor([[ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n",
" 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7],\n",
" [ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n",
" 1217, 7, -1, -1, -1, -1, -1, -1, -1, -1]]), 'seq_len': tensor([20, 12])}\n",
"batch_y: {'target': tensor([0, 1])}\n"
]
}
],
"source": [
"tmp_data.set_pad_val('words',-1)\n",
"batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n",
"for batch_x, batch_y in batch:\n",
" print(\"batch_x: \",batch_x)\n",
" print(\"batch_y: \", batch_y)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_x: {'words': tensor([[ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n",
" 1217, 7, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n",
" 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([12, 20])}\n",
"batch_y: {'target': tensor([1, 0])}\n",
"batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n",
" 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n",
" 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n",
" 1323, 4398, 7, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n",
" 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n",
" 7, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([33, 21])}\n",
"batch_y: {'target': tensor([1, 0])}\n",
"batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0],\n",
" [ 14, 10, 437, 32, 78, 3, 78, 437, 7, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0]]), 'seq_len': tensor([9, 9])}\n",
"batch_y: {'target': tensor([0, 1])}\n",
"batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n",
" 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n",
" 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([20, 20])}\n",
"batch_y: {'target': tensor([0, 0])}\n",
"batch_x: {'words': tensor([[ 4, 277, 685, 18, 7, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [15618, 3204, 5, 1675, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([5, 4])}\n",
"batch_y: {'target': tensor([1, 1])}\n"
]
}
],
"source": [
"from fastNLP.core.field import Padder\n",
"import numpy as np\n",
"class FixLengthPadder(Padder):\n",
" def __init__(self, pad_val=0, length=None):\n",
" super().__init__(pad_val=pad_val)\n",
" self.length = length\n",
" assert self.length is not None, \"Creating FixLengthPadder with no specific length!\"\n",
"\n",
" def __call__(self, contents, field_name, field_ele_dtype, dim):\n",
" #计算当前contents中的最大长度\n",
" max_len = max(map(len, contents))\n",
" #如果当前contents中的最大长度大于指定的padder length的话就报错\n",
" assert max_len <= self.length, \"Fixed padder length smaller than actual length! with length {}\".format(max_len)\n",
" array = np.full((len(contents), self.length), self.pad_val, dtype=field_ele_dtype)\n",
" for i, content_i in enumerate(contents):\n",
" array[i, :len(content_i)] = content_i\n",
" return array\n",
"\n",
"#设定FixLengthPadder的固定长度为40\n",
"tmp_padder = FixLengthPadder(pad_val=0,length=40)\n",
"#利用dataset的set_padder函数设定words field的padder\n",
"tmp_data.set_padder('words',tmp_padder)\n",
"batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n",
"for batch_x, batch_y in batch:\n",
" print(\"batch_x: \",batch_x)\n",
" print(\"batch_y: \", batch_y)"
]
},
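{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the output above is padded with 0 again even though `set_pad_val('words', -1)` was called earlier: the new padder carries its own `pad_val=0`. A padder is also just a callable, so it can be checked standalone before attaching it to a field. A minimal sketch (`field_name` and `dim` are unused by this padder and passed only to satisfy the signature):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pad two toy id sequences to the fixed length of 40 and inspect the result.\n",
"arr = tmp_padder([[1, 2, 3], [4, 5]], field_name='words', field_ele_dtype=int, dim=1)\n",
"print(arr.shape)   # expected: (2, 40)\n",
"print(arr[1][:8])  # second row: 4, 5, then pad_val zeros"
]
},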
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 使用DataSetIter自己编写训练过程\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-----start training-----\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 2.68 seconds!\n",
"Epoch 0 Avg Loss: 0.66 AccuracyMetric: acc=0.708716 29307ms\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.38 seconds!\n",
"Epoch 1 Avg Loss: 0.41 AccuracyMetric: acc=0.770642 52200ms\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.51 seconds!\n",
"Epoch 2 Avg Loss: 0.16 AccuracyMetric: acc=0.747706 70268ms\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.96 seconds!\n",
"Epoch 3 Avg Loss: 0.06 AccuracyMetric: acc=0.741972 90349ms\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 1.04 seconds!\n",
"Epoch 4 Avg Loss: 0.03 AccuracyMetric: acc=0.740826 114250ms\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.8 seconds!\n",
"Epoch 5 Avg Loss: 0.02 AccuracyMetric: acc=0.738532 134742ms\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.65 seconds!\n",
"Epoch 6 Avg Loss: 0.01 AccuracyMetric: acc=0.731651 154503ms\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.8 seconds!\n",
"Epoch 7 Avg Loss: 0.01 AccuracyMetric: acc=0.738532 175397ms\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.36 seconds!\n",
"Epoch 8 Avg Loss: 0.01 AccuracyMetric: acc=0.733945 192384ms\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.84 seconds!\n",
"Epoch 9 Avg Loss: 0.01 AccuracyMetric: acc=0.744266 214417ms\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5.0), HTML(value='')), layout=Layout(disp…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.04 seconds!\n",
"[tester] \n",
"AccuracyMetric: acc=0.786667\n"
]
},
{
"data": {
"text/plain": [
"{'AccuracyMetric': {'acc': 0.786667}}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from fastNLP import BucketSampler\n",
"from fastNLP import DataSetIter\n",
"from fastNLP.models import CNNText\n",
"from fastNLP import Tester\n",
"import torch\n",
"import time\n",
"\n",
"embed_dim = 100\n",
"model = CNNText((len(vocab),embed_dim), num_classes=2, dropout=0.1)\n",
"\n",
"def train(epoch, data, devdata):\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
" lossfunc = torch.nn.CrossEntropyLoss()\n",
" batch_size = 32\n",
"\n",
" # 定义一个Batch传入DataSet规定batch_size和去batch的规则。\n",
" # 顺序Sequential随机Random相似长度组成一个batchBucket\n",
" train_sampler = BucketSampler(batch_size=batch_size, seq_len_field_name='seq_len')\n",
" train_batch = DataSetIter(batch_size=batch_size, dataset=data, sampler=train_sampler)\n",
"\n",
" start_time = time.time()\n",
" print(\"-\"*5+\"start training\"+\"-\"*5)\n",
" for i in range(epoch):\n",
" loss_list = []\n",
" for batch_x, batch_y in train_batch:\n",
" optimizer.zero_grad()\n",
" output = model(batch_x['words'])\n",
" loss = lossfunc(output['pred'], batch_y['target'])\n",
" loss.backward()\n",
" optimizer.step()\n",
" loss_list.append(loss.item())\n",
"\n",
" #这里verbose如果为0在调用Tester对象的test()函数时不输出任何信息,返回评估信息; 如果为1打印出验证结果返回评估信息\n",
" #在调用过Tester对象的test()函数后调用其_format_eval_results(res)函数,结构化输出验证结果\n",
" tester_tmp = Tester(devdata, model, metrics=AccuracyMetric(), verbose=0)\n",
" res=tester_tmp.test()\n",
"\n",
" print('Epoch {:d} Avg Loss: {:.2f}'.format(i, sum(loss_list) / len(loss_list)),end=\" \")\n",
" print(tester_tmp._format_eval_results(res),end=\" \")\n",
" print('{:d}ms'.format(round((time.time()-start_time)*1000)))\n",
" loss_list.clear()\n",
"\n",
"train(10, train_data, dev_data)\n",
"#使用tester进行快速测试\n",
"tester = Tester(test_data, model, metrics=AccuracyMetric())\n",
"tester.test()"
]
},
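{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the model trained, single-instance inference is a plain forward pass. A minimal sketch, assuming the cells above have been run and that `CNNText` returns a dict with a `'pred'` score tensor (as used in the training loop):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Wrap one dev instance's word ids in a batch dimension, run the model,\n",
"# and compare the argmax class against the gold target.\n",
"words = torch.LongTensor([dev_data[0]['words']])\n",
"pred = model(words)['pred'].argmax(dim=-1)\n",
"print('pred:', pred.item(), 'target:', dev_data[0]['target'])"
]
},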
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python Now",
"language": "python",
"name": "now"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}