{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# fastNLP Advanced Tutorial\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Organizing Data\n",
"## DataSet & Instance\n",
"fastNLP uses DataSet and Instance to store and process data. A DataSet represents a dataset and an Instance represents a single sample; a DataSet holds multiple Instances, and each Instance can hold whatever fields the user defines."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/remote-home/ygxu/anaconda3/envs/no-fastnlp/lib/python3.7/site-packages/tqdm/autonotebook/__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
"  \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n"
]
}
],
"source": [
"# import the components we need\n",
"import torch\n",
"import fastNLP\n",
"from fastNLP import DataSet\n",
"from fastNLP import Instance\n",
"from fastNLP import Vocabulary\n",
"from fastNLP import Trainer\n",
"from fastNLP import Tester"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Instance\n",
"An Instance represents one sample and consists of one or more fields (attributes/features), each with its own name and value.\n",
"When constructing an Instance you define the fields it contains, using the \"field_name=field_value\" syntax"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'premise': an premise example . type=str,\n",
"'hypothesis': an hypothesis example. type=str,\n",
"'label': 1 type=int}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# build an Instance consisting of three fields: premise, hypothesis and label\n",
"instance = Instance(premise='an premise example .', hypothesis='an hypothesis example.', label=1)\n",
"instance"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DataSet({'premise': an premise example . type=str,\n",
"'hypothesis': an hypothesis example. type=str,\n",
"'label': 1 type=int},\n",
"{'premise': an premise example . type=str,\n",
"'hypothesis': an hypothesis example. type=str,\n",
"'label': 1 type=int})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_set = DataSet([instance] * 5)\n",
"data_set.append(instance)\n",
"data_set[-2: ]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DataSet({'premise': an premise example . type=str,\n",
"'hypothesis': an hypothesis example. type=str,\n",
"'label': 1 type=int},\n",
"{'premise': the second premise example . type=str,\n",
"'hypothesis': the second hypothesis example. type=str,\n",
"'label': 1 type=str})"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# an instance whose field type differs from the dataset's corresponding field type can still be added to the dataset\n",
"instance2 = Instance(premise='the second premise example .', hypothesis='the second hypothesis example.', label='1')\n",
"try:\n",
"    data_set.append(instance2)\n",
"except:\n",
"    pass\n",
"data_set[-2: ]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cannot append instance\n"
]
},
{
"data": {
"text/plain": [
"DataSet({'premise': an premise example . type=str,\n",
"'hypothesis': an hypothesis example. type=str,\n",
"'label': 1 type=int},\n",
"{'premise': the second premise example . type=str,\n",
"'hypothesis': the second hypothesis example. type=str,\n",
"'label': 1 type=str})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# if a field name does not match, the instance cannot be appended to the dataset\n",
"instance3 = Instance(premises='the third premise example .', hypothesis='the third hypothesis example.', label=1)\n",
"try:\n",
"    data_set.append(instance3)\n",
"except:\n",
"    print('cannot append instance')\n",
"    pass\n",
"data_set[-2: ]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DataSet({'image': tensor([[ 2.1747, -1.0147, -1.3853,  0.0216, -0.4957],\n",
"        [ 0.8138, -0.2933, -0.1217, -0.6027,  0.3932],\n",
"        [ 0.6750, -1.1136, -1.3371, -0.0185, -0.3206],\n",
"        [-0.5076, -0.3822,  0.1719, -0.6447, -0.5702],\n",
"        [ 0.3804,  0.0889,  0.8027, -0.7121, -0.7320]]) type=torch.Tensor,\n",
"'label': 0 type=int})"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# besides text, a tensor can also be the value of a field\n",
"import torch\n",
"tensor_ins = Instance(image=torch.randn(5, 5), label=0)\n",
"ds = DataSet()\n",
"ds.append(tensor_ins)\n",
"ds"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### DataSet\n",
"### Reading and Building a DataSet with Existing Utilities\n",
"The DataSet class provides several read_* methods that read data from files and build a DataSet"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"77"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import fastNLP\n",
"from fastNLP import DataSet\n",
"from fastNLP import Instance\n",
"\n",
"# read data from a csv file into a DataSet\n",
"# any csv-like file, i.e. one example per line, can be read this way\n",
"dataset = DataSet.read_csv('tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), sep='\\t')\n",
"# check the size of the DataSet\n",
"len(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n",
"'label': 1 type=str}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# use integer indexing [k] to fetch the k-th sample\n",
"dataset[0]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"fastNLP.core.instance.Instance"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# the fetched sample is an Instance\n",
"type(dataset[0])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DataSet({'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n",
"'label': 1 type=str},\n",
"{'raw_sentence': This quiet , introspective and entertaining independent is worth seeking . type=str,\n",
"'label': 4 type=str},\n",
"{'raw_sentence': Even fans of Ismail Merchant 's work , I suspect , would have a hard time sitting through this one . type=str,\n",
"'label': 1 type=str})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# use slicing [a: b] to fetch samples a through b\n",
"dataset[0: 3]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'raw_sentence': A film that clearly means to preach exclusively to the converted . type=str,\n",
"'label': 2 type=str}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# indices can also be negative\n",
"dataset[-1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reading and Building a DataSet Yourself\n",
"Take the SNLI dataset as an example.\n",
"Each of SNLI's train, dev and test splits consists of three files: every line of the first file is a sentence, the premise of one example; every line of the second file is also a sentence, the hypothesis of the same example; every line of the third file is a label"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'premise': A person on a horse jumps over a broken down airplane . type=str,\n",
"'hypothesis': A person is training his horse for a competition . type=str,\n",
"'truth': 1\n",
" type=str}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_path = ['premise', 'hypothesis', 'label']\n",
"\n",
"# read the files\n",
"with open(data_path[0]) as f:\n",
"    premise = f.readlines()\n",
"\n",
"with open(data_path[1]) as f:\n",
"    hypothesis = f.readlines()\n",
"\n",
"with open(data_path[2]) as f:\n",
"    label = f.readlines()\n",
"\n",
"assert len(premise) == len(hypothesis) and len(hypothesis) == len(label)\n",
"\n",
"# build the DataSet\n",
"data_set = DataSet()\n",
"for p, h, l in zip(premise, hypothesis, label):\n",
"    p = p.strip()  # strip trailing whitespace\n",
"    h = h.strip()  # strip trailing whitespace\n",
"    data_set.append(Instance(premise=p, hypothesis=h, truth=l))\n",
"\n",
"data_set[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Other DataSet Operations\n",
"After a DataSet has been built, its contents can still be manipulated; the interface for this is DataSet.apply()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DataSet({'premise': a woman is walking across the street eating a banana , while a man is following with his briefcase . type=str,\n",
"'hypothesis': An actress and her favorite assistant talk a walk in the city . type=str,\n",
"'truth': 1\n",
" type=str},\n",
"{'premise': a woman is walking across the street eating a banana , while a man is following with his briefcase . type=str,\n",
"'hypothesis': a woman eating a banana crosses a street type=str,\n",
"'truth': 0\n",
" type=str})"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# lowercase all text in the premise field\n",
"data_set.apply(lambda x: x['premise'].lower(), new_field_name='premise')\n",
"data_set[-2: ]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DataSet({'premise': a woman is walking across the street eating a banana , while a man is following with his briefcase . type=str,\n",
"'hypothesis': An actress and her favorite assistant talk a walk in the city . type=str,\n",
"'truth': 1 type=int},\n",
"{'premise': a woman is walking across the street eating a banana , while a man is following with his briefcase . type=str,\n",
"'hypothesis': a woman eating a banana crosses a street type=str,\n",
"'truth': 0 type=int})"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# convert the labels to int\n",
"data_set.apply(lambda x: int(x['truth']), new_field_name='truth')\n",
"data_set[-2: ]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DataSet({'premise': ['a', 'woman', 'is', 'walking', 'across', 'the', 'street', 'eating', 'a', 'banana', ',', 'while', 'a', 'man', 'is', 'following', 'with', 'his', 'briefcase', '.'] type=list,\n",
"'hypothesis': ['An', 'actress', 'and', 'her', 'favorite', 'assistant', 'talk', 'a', 'walk', 'in', 'the', 'city', '.'] type=list,\n",
"'truth': 1 type=int},\n",
"{'premise': ['a', 'woman', 'is', 'walking', 'across', 'the', 'street', 'eating', 'a', 'banana', ',', 'while', 'a', 'man', 'is', 'following', 'with', 'his', 'briefcase', '.'] type=list,\n",
"'hypothesis': ['a', 'woman', 'eating', 'a', 'banana', 'crosses', 'a', 'street'] type=list,\n",
"'truth': 0 type=int})"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# split sentences on whitespace\n",
"def split_sent(ins):\n",
"    return ins['premise'].split()\n",
"data_set.apply(split_sent, new_field_name='premise')\n",
"data_set.apply(lambda x: x['hypothesis'].split(), new_field_name='hypothesis')\n",
"data_set[-2:]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(100, 97)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# filter the data\n",
"origin_data_set_len = len(data_set)\n",
"data_set.drop(lambda x: len(x['premise']) <= 6)\n",
"origin_data_set_len, len(data_set)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'premise': ['a', 'woman', 'is', 'walking', 'across', 'the', 'street', 'eating', 'a', 'banana', ',', 'while', 'a', 'man', 'is', 'following', 'with', 'his', 'briefcase', '.'] type=list,\n",
"'hypothesis': ['a', 'woman', 'eating', 'a', 'banana', 'crosses', 'a', 'street'] type=list,\n",
"'truth': 0 type=int,\n",
"'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
"'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1] type=list}"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# add length information\n",
"data_set.apply(lambda x: [1] * len(x['premise']), new_field_name='premise_len')\n",
"data_set.apply(lambda x: [1] * len(x['hypothesis']), new_field_name='hypothesis_len')\n",
"data_set[-1]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# mark the input fields and the target fields\n",
"data_set.set_input(\"premise\", \"premise_len\", \"hypothesis\", \"hypothesis_len\")\n",
"data_set.set_target(\"truth\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'premise': ['a', 'woman', 'is', 'walking', 'across', 'the', 'street', 'eating', 'a', 'banana', ',', 'while', 'a', 'man', 'is', 'following', 'with', 'his', 'briefcase', '.'] type=list,\n",
"'hypothesis': ['a', 'woman', 'eating', 'a', 'banana', 'crosses', 'a', 'street'] type=list,\n",
"'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
"'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
"'label': 0 type=int}"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# rename a field\n",
"data_set.rename_field('truth', 'label')\n",
"data_set[-1]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(49, 29, 19)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# split into training, validation and test sets\n",
"train_data, vad_data = data_set.split(0.5)\n",
"dev_data, test_data = vad_data.split(0.4)\n",
"len(train_data), len(dev_data), len(test_data)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# deep-copy a dataset\n",
"import copy\n",
"train_data_2, dev_data_2 = copy.deepcopy(train_data), copy.deepcopy(dev_data)\n",
"del copy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### DataSet Summary:\n",
"Once fields of a DataSet have been set as input and target, they are used by the code that follows. Fields set as input are passed to <strong>Model.forward</strong> by matching field names against parameter names. For example: \n",
"Suppose 'x1', 'x2', 'x3' in the DataSet are set as input; then \n",
"   (1) if the signature is Model.forward(self, x1, x3), 'x1' and 'x3' from the DataSet are passed to forward, and the extra 'x2' is ignored \n",
"   (2) if the signature is Model.forward(self, x1, x4), an extra 'x4' is required, but the DataSet has no input field of that name, so an error is raised \n",
"   (3) if the signature is Model.forward(self, x1, **kwargs), 'x1', 'x2' and 'x3' are all passed in; Model.forward(self, x4, **kwargs), however, raises an error, because there is no 'x4' \n",
"   (4) for a field set as target, we recommend the name 'target' (when only one value needs to be predicted), though this is not enforced; if the name is not 'target', a name-mapping map must be supplied manually when the loss function is loaded. A sketch of this convention follows below"
]
},
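{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the name-matching convention above; the model and field names here are illustrative only, not part of fastNLP:\n",
"```Python\n",
"import torch.nn as nn\n",
"\n",
"class ToyModel(nn.Module):\n",
"    # parameter names must match the names of input fields in the DataSet\n",
"    def forward(self, x1, x3):\n",
"        # an input field 'x2' in the DataSet would simply be ignored\n",
"        return {'pred': x1 + x3}\n",
"```"
]
},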
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Vocabulary\n",
"Vocabulary in fastNLP makes it easy to build a vocabulary and map words to indices. A vocabulary can be built in two ways: from a dataset, or by loading an existing vocabulary\n",
"### Building a Vocabulary from a Dataset"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# initialize a vocabulary with a maximum vocab_size of 10000 and a minimum word frequency of 2; '<unk>' marks unknown words, '<pad>' the padding token\n",
"# Vocabulary's default init parameters are max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'\n",
"vocab = Vocabulary(max_size=10000, min_freq=2, unknown='<unk>', padding='<pad>')\n",
"\n",
"# build the vocabulary\n",
"train_data.apply(lambda x: [vocab.add(word) for word in x['premise']])\n",
"train_data.apply(lambda x: [vocab.add(word) for word in x['hypothesis']])\n",
"vocab.build_vocab()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"({'premise': [2, 145, 146, 80, 147, 26, 148, 2, 104, 149, 150, 2, 151, 5, 55, 152, 105, 3] type=list,\n",
" 'hypothesis': [22, 80, 8, 1, 1, 20, 1, 3] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'label': 2 type=int},\n",
" {'premise': [11, 5, 18, 5, 24, 6, 2, 10, 59, 52, 14, 9, 2, 53, 29, 60, 54, 45, 6, 46, 5, 7, 61, 3] type=list,\n",
" 'hypothesis': [22, 11, 1, 45, 3] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1] type=list,\n",
" 'label': 1 type=int},\n",
" {'premise': [2, 11, 8, 14, 16, 7, 15, 50, 2, 66, 4, 76, 2, 10, 8, 98, 9, 58, 67, 3] type=list,\n",
" 'hypothesis': [22, 27, 50, 3] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1] type=list,\n",
" 'label': 0 type=int})"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# index the sentences with the vocabulary\n",
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise')\n",
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis')\n",
"dev_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise')\n",
"dev_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis')\n",
"test_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise')\n",
"test_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis')\n",
"train_data[-1], dev_data[-1], test_data[-1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading an Existing Vocabulary\n",
"Take a BERT pretrained model as an example: its vocabulary is stored in a vocab.txt file\n",
"The following approach loads an existing vocabulary while keeping its word order unchanged"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# read the vocab file\n",
"with open('vocab.txt') as f:\n",
"    lines = f.readlines()\n",
"vocabs = []\n",
"for line in lines:\n",
"    vocabs.append(line.strip())\n",
"\n",
"# instantiate a Vocabulary\n",
"vocab_bert = Vocabulary(unknown=None, padding=None)\n",
"# add the vocabs list to the Vocabulary\n",
"vocab_bert.add_word_lst(vocabs)\n",
"# build the vocabulary\n",
"vocab_bert.build_vocab()\n",
"# update the unknown and padding token strings\n",
"vocab_bert.unknown = '[UNK]'\n",
"vocab_bert.padding = '[PAD]'"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"({'premise': [1037, 2210, 2223, 2136, 5363, 2000, 4608, 1037, 5479, 8058, 2046, 1037, 2918, 1999, 2019, 5027, 2208, 1012] type=list,\n",
" 'hypothesis': [100, 2136, 2003, 2652, 3598, 2006, 100, 1012] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'label': 2 type=int},\n",
" {'premise': [2450, 1999, 2317, 1999, 100, 1998, 1037, 2158, 3621, 2369, 3788, 2007, 1037, 3696, 2005, 2198, 100, 10733, 1998, 100, 1999, 1996, 4281, 1012] type=list,\n",
" 'hypothesis': [100, 2450, 13063, 10733, 1012] type=list,\n",
" 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
" 'hypothesis_len': [1, 1, 1, 1, 1] type=list,\n",
" 'label': 1 type=int})"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# index the sentences with the vocabulary\n",
"train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['premise']], new_field_name='premise')\n",
"train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis')\n",
"dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['premise']], new_field_name='premise')\n",
"dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis')\n",
"train_data_2[-1], dev_data_2[-1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Models\n",
"## Model\n",
"fastNLP supports two ways of working with models: using a model that ships with fastNLP, or building your own\n",
"### Using an Existing fastNLP Model"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'embed_dim': 300,\n",
" 'hidden_size': 300,\n",
" 'batch_first': True,\n",
" 'dropout': 0.3,\n",
" 'num_classes': 3,\n",
" 'gpu': True,\n",
" 'batch_size': 32,\n",
" 'vocab_size': 165}"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# step 1: load the model hyperparameters (optional)\n",
"from fastNLP.io.config_io import ConfigSection, ConfigLoader\n",
"args = ConfigSection()\n",
"ConfigLoader().load_config(\"./data/config\", {\"esim_model\": args})\n",
"args[\"vocab_size\"] = len(vocab)\n",
"args.data"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ESIM(\n",
"  (drop): Dropout(p=0.3)\n",
"  (embedding): Embedding(\n",
"    (embed): Embedding(165, 300, padding_idx=0)\n",
"    (dropout): Dropout(p=0.3)\n",
"  )\n",
"  (embedding_layer): Linear(\n",
"    (linear): Linear(in_features=300, out_features=300, bias=True)\n",
"  )\n",
"  (encoder): LSTM(\n",
"    (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)\n",
"  )\n",
"  (bi_attention): Bi_Attention()\n",
"  (mean_pooling): MeanPoolWithMask()\n",
"  (max_pooling): MaxPoolWithMask()\n",
"  (inference_layer): Linear(\n",
"    (linear): Linear(in_features=1200, out_features=300, bias=True)\n",
"  )\n",
"  (decoder): LSTM(\n",
"    (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)\n",
"  )\n",
"  (output): MLP(\n",
"    (hiddens): ModuleList(\n",
"      (0): Linear(in_features=1200, out_features=300, bias=True)\n",
"    )\n",
"    (output): Linear(in_features=300, out_features=3, bias=True)\n",
"    (dropout): Dropout(p=0.3)\n",
"    (hidden_active): Tanh()\n",
"  )\n",
")"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# step 2: load the ESIM model\n",
"from fastNLP.models import ESIM\n",
"model = ESIM(**args.data)\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"CNNText(\n",
"  (embed): Embedding(\n",
"    (embed): Embedding(165, 50, padding_idx=0)\n",
"    (dropout): Dropout(p=0.0)\n",
"  )\n",
"  (conv_pool): ConvMaxpool(\n",
"    (convs): ModuleList(\n",
"      (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n",
"      (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n",
"      (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n",
"    )\n",
"  )\n",
"  (dropout): Dropout(p=0.1)\n",
"  (fc): Linear(\n",
"    (linear): Linear(in_features=12, out_features=5, bias=True)\n",
"  )\n",
")"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# another example: load the CNN text classification model\n",
"from fastNLP.models import CNNText\n",
"cnn_text_model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n",
"cnn_text_model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is the forward method of the model above. If you are not sure what a forward method is, please refer to our PyTorch tutorial.\n",
"\n",
"Note two things:\n",
"1. The forward parameter is named **word_seq**; keep that in mind.\n",
"2. forward returns a **dict** containing a key named **pred**.\n",
"\n",
"```Python\n",
"    def forward(self, word_seq):\n",
"        \"\"\"\n",
"\n",
"        :param word_seq: torch.LongTensor, [batch_size, seq_len]\n",
"        :return output: dict of torch.LongTensor, [batch_size, num_classes]\n",
"        \"\"\"\n",
"        x = self.embed(word_seq)  # [N,L] -> [N,L,C]\n",
"        x = self.conv_pool(x)  # [N,L,C] -> [N,C]\n",
"        x = self.dropout(x)\n",
"        x = self.fc(x)  # [N,C] -> [N, N_class]\n",
"        return {'pred': x}\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Building Your Own Model\n",
"A model you build yourself must be a subclass of nn.Module, and: \n",
"(1) it must implement a forward method, and forward must not take **\\*args**-style parameters, e.g.\n",
"```Python\n",
"    def forward(self, word_seq, *args):  # this is not allowed\n",
"        xxx\n",
"```\n",
"The return value of forward must be a **dict**. \n",
"For the model's predicted values in that dict we recommend the key **'pred'**; this is not enforced, but if the key is not 'pred', a name-mapping map must be supplied manually when the loss function is loaded \n",
"(2) if a predict method is implemented, evaluation calls predict instead of forward; without a predict method, evaluation calls forward. predict likewise must not take \\*args-style parameters and must also return a dict, again preferably keyed by 'pred'. \n",
"(3) forward may compute the loss itself and return it in the dict under the key 'loss', e.g.: \n",
"```Python\n",
"    def forward(self, word_seq): \n",
"        xxx\n",
"        return {'pred': pred, 'loss': loss}\n",
"```\n",
"When no loss function is defined in the trainer, the trainer backpropagates using the value under the 'loss' key of the dict returned by forward; see the Loss section for details. A minimal sketch of a conforming model follows below"
]
},
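{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of a custom model following the rules above; the class name and layer sizes are illustrative, not part of fastNLP:\n",
"```Python\n",
"import torch.nn as nn\n",
"\n",
"class MyTextModel(nn.Module):\n",
"    def __init__(self, vocab_size, embed_dim=50, num_classes=3):\n",
"        super(MyTextModel, self).__init__()\n",
"        self.embed = nn.Embedding(vocab_size, embed_dim)\n",
"        self.fc = nn.Linear(embed_dim, num_classes)\n",
"\n",
"    def forward(self, word_seq):  # named parameters only, no *args\n",
"        x = self.embed(word_seq)  # [N, L] -> [N, L, C]\n",
"        x = x.mean(dim=1)  # [N, L, C] -> [N, C]\n",
"        return {'pred': self.fc(x)}  # 'pred' matches the default loss mapping\n",
"```"
]
},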
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training and Testing\n",
"## Loss\n",
"### Key Mapping\n",
"As the DataSet and Model sections above show, fastNLP constrains neither the return value of Model.forward() nor the keys of the DataSet's target fields, so loss computation fetches its inputs through key mapping. \n",
"Take CrossEntropyLoss as an example; the relevant part of our cross-entropy function is:\n",
"```Python\n",
"    def get_loss(self, pred, target):\n",
"        return F.cross_entropy(input=pred, target=target,\n",
"                               ignore_index=self.padding_idx)\n",
"```\n",
"The two parameters received here are named pred and target: pred is taken from the dict returned by the model's forward function, and target from the DataSet fields marked is_target. With no key mapping set up, pred is looked up under the key 'pred' in the dict returned by forward, and target is taken from the DataSet field named 'target'."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Customizing the Key Mapping\n",
"When initializing CrossEntropyLoss you may pass two parameters (pred=None, target=None), both of type str. Given (pred='output', target='label'), CrossEntropyLoss uses the key 'output' to look up a value in forward's output and in batch_y, and likewise uses 'label' to look up a value in forward's output and in the dataset's is_target fields. Note that pred and target do not have to come from model.forward and the dataset's is_target fields respectively; both may come from forward's output alone. An example with the names used in this tutorial follows below"
]
},
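{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, with the names used earlier in this tutorial (the model returns its prediction under 'pred' and the target field is named 'label'), the mapping is set up like this:\n",
"```Python\n",
"from fastNLP import CrossEntropyLoss\n",
"\n",
"# look up 'pred' in forward's output and 'label' among the target fields\n",
"loss = CrossEntropyLoss(pred='pred', target='label')\n",
"```"
]
},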
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating Your Own Loss\n",
"    (1) Use fastNLP.LossInForward and include a 'loss' key in the result of model.forward(); see the **Building Your Own Model** section \n",
"    (2) Define a class that inherits from fastNLP.core.loss.LossBase and override its get_loss method. \n",
"      (2.1) when initializing your loss class, initialize the values that need to be mapped \n",
"      (2.2) in get_loss, the parameter names must be consistent with the mapping given at initialization \n",
"Take L1Loss as an example:\n",
"```Python\n",
"class L1Loss(LossBase):\n",
"    def __init__(self, pred=None, target=None):\n",
"        super(L1Loss, self).__init__()\n",
"        \"\"\"\n",
"        pred and target are passed to _init_param_map so that they are registered correctly. This step is not\n",
"        mandatory, but it is recommended. _init_param_map receives the key/value pairs used for \"key mapping\".\n",
"        If __init__(pred=None, target=None, threshold=0.1) had a threshold parameter that merely controls the\n",
"        loss computation, threshold should NOT be passed to _init_param_map.\n",
"        \"\"\"\n",
"        self._init_param_map(pred=pred, target=target)\n",
"\n",
"    def get_loss(self, pred, target):\n",
"        # 'pred' and 'target' here must be consistent with the mapping given at initialization.\n",
"        return F.l1_loss(input=pred, target=target)\n",
"```\n"
]
},
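{
"cell_type": "markdown",
"metadata": {},
"source": [
"The custom loss is then constructed just like a built-in one; the field names here are the hypothetical ones registered above:\n",
"```Python\n",
"# 'pred' and 'target' name the keys to look up in forward's output\n",
"# and among the target fields, exactly as with CrossEntropyLoss\n",
"l1 = L1Loss(pred='pred', target='target')\n",
"```"
]
},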
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Trainer\n",
"The trainer trains a model; it is a wrapper around the training process. Its most important function is trainer.train(), whose main steps are: \n",
"(1) create batches \n",
"```Python\n",
"batch = Batch(dataset, batch_size, sampler=sampler)\n",
"```\n",
"(2) for batch_x, batch_y in batch: (batch_x and batch_y hold the is_input and is_target parts of the dataset respectively; the keys of these two dicts are the DataSet's keys, and the values are padded and turned into tensors as needed) \n",
"  (2.1) move the tensors in batch_x and batch_y to the device the model is on \n",
"  (2.2) based on the parameter list of model.forward, take from batch_x the data to be passed to forward \n",
"  (2.3) take the dict returned by model.forward and pass it, together with batch_y, to the loss function to compute the loss \n",
"  (2.4) backpropagate the loss and update the parameters \n",
"(3) if a validation set is given, run validation \n",
"(4) if the validation result is the best so far, save the model. A sketch of the batch iteration in step (2) follows below"
]
},
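{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick look at the batch iteration, assuming the Batch and RandomSampler classes from fastNLP.core (a sketch for inspection, not part of the training loop itself):\n",
"```Python\n",
"from fastNLP.core.batch import Batch\n",
"from fastNLP.core.sampler import RandomSampler\n",
"\n",
"batch = Batch(dataset=train_data, batch_size=16, sampler=RandomSampler())\n",
"for batch_x, batch_y in batch:\n",
"    # batch_x holds the is_input fields, batch_y the is_target fields\n",
"    print(batch_x.keys(), batch_y.keys())\n",
"    break\n",
"```"
]
},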
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to the above, the Trainer offers a \"dry run\" feature, controlled by check_code_level. With check_code_level=-1, no dry run is performed; levels 0, 1 and 2 set how fields that are marked input or target but never used are reported: 0 ignores them (the default), 1 warns about unused fields, and 2 raises an error and aborts when an unused field occurs. The dry run has two main purposes: \n",
"(1) to prevent evaluation from failing after training has completed, which would waste the whole training run \n",
"(2) because of the \"key mapping\", errors from a direct run can be hard to debug, whereas errors raised during the dry run come with debugging hints. The dry run performs the following steps: \n",
"  (i) with a very small batch_size, check whether batch_x contains the parameters Model.forward needs; only two iterations are run \n",
"  (ii) feed Model.forward's output pred_dict, together with batch_y, into the loss and attempt a backward pass; parameters are not updated and gradients are zeroed. If dev_data was passed in, the metric is tested as well \n",
"  (iii) create a Tester, feed it a small amount of data, and check that it runs normally \n",
"The dry run is executed when the Trainer is initialized. Normally there should be no need to modify the dry-run code, but if you run into a bug or have suggestions, feel free to raise them in the developer group or open an issue on github."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2019-01-09 00-08-17\n",
"[tester] \n",
"AccuracyMetric: acc=0.206897\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/remote-home/ygxu/anaconda3/envs/no-fastnlp/lib/python3.7/site-packages/torch/nn/functional.py:1320: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n",
"  warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[tester] \n",
"AccuracyMetric: acc=0.206897\n",
"[tester] \n",
"AccuracyMetric: acc=0.206897\n",
"[tester] \n",
"AccuracyMetric: acc=0.206897\n",
"[tester] \n",
"AccuracyMetric: acc=0.206897\n",
"\n",
"In Epoch:1/Step:4, got best dev performance:AccuracyMetric: acc=0.206897\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccuracyMetric': {'acc': 0.206897}},\n",
" 'best_epoch': 1,\n",
" 'best_step': 4,\n",
" 'seconds': 0.79}"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from fastNLP import CrossEntropyLoss\n",
"from fastNLP import Adam\n",
"from fastNLP import AccuracyMetric\n",
"trainer = Trainer(\n",
"    train_data=train_data,\n",
"    model=model,\n",
"    loss=CrossEntropyLoss(pred='pred', target='label'),\n",
"    metrics=AccuracyMetric(),\n",
"    n_epochs=5,\n",
"    batch_size=16,\n",
"    print_every=-1,\n",
"    validate_every=-1,\n",
"    dev_data=dev_data,\n",
"    use_cuda=True,\n",
"    optimizer=Adam(lr=1e-3, weight_decay=0),\n",
"    check_code_level=-1,\n",
"    metric_key='acc',\n",
"    use_tqdm=False,\n",
")\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tester\n",
"The tester evaluates a model; it is a wrapper around the evaluation process. Its most important function is tester.test(), whose main steps are:\n",
"(1) create batches\n",
"```Python\n",
"batch = Batch(dataset, batch_size, sampler=sampler)\n",
"```\n",
"(2) for batch_x, batch_y in batch: (batch_x and batch_y hold the is_input and is_target parts of the dataset respectively; the keys of these two dicts are the DataSet's keys, and the values are padded and turned into tensors as needed) \n",
"  (2.1) keep data and model in sync: move the tensors in batch_x and batch_y to the device the model is on \n",
"  (2.2) based on the parameter list of predict_func, take from batch_x the data to be passed to predict_func, producing pred_dict \n",
"  (2.3) call metric(pred_dict, batch_y) \n",
"  (2.4) after all batches have run, call the metric's get_metric method and use its return value as the evaluation result "
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[tester] \n",
"AccuracyMetric: acc=0.263158\n"
]
},
{
"data": {
"text/plain": [
"{'AccuracyMetric': {'acc': 0.263158}}"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tester = Tester(\n",
"    data=test_data,\n",
"    model=model,\n",
"    metrics=AccuracyMetric(),\n",
"    batch_size=args[\"batch_size\"],\n",
")\n",
"tester.test()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}