Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

This commit is contained in:
yh 2022-05-17 22:42:11 +08:00
commit d5a4c7aaaa
4 changed files with 814 additions and 212 deletions

View File

@ -434,7 +434,7 @@
"\n",
"  通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n",
"\n",
"    但对于`\"auto\"`和`\"rich\"`格式,训练结束后进度条会不显示(???)\n",
"    但对于`\"auto\"`和`\"rich\"`格式,在notebook中进度条在训练结束后会被丢弃\n",
"\n",
"  通过`n_epochs`设定优化迭代轮数默认为20全部`Trainer`的全部变量与函数可以通过`dir(trainer)`查询"
]
@ -523,19 +523,6 @@
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
@ -600,7 +587,7 @@
"\n",
"&emsp; 其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止,默认全部\n",
"\n",
"&emsp; 最终,输出形如`{'acc#acc': acc}`的字典,中间的进度条会在运行结束后丢弃掉(???)"
"&emsp; 最终,输出形如`{'acc#acc': acc}`的字典,在notebook中进度条在评测结束后会被丢弃"
]
},
{
@ -626,34 +613,11 @@
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.41</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'total#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'correct#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">41.0</span><span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.39</span><span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.39\u001b[0m\u001b[1m}\u001b[0m\n"
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.41\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m41.0\u001b[0m\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
@ -662,7 +626,7 @@
{
"data": {
"text/plain": [
"{'acc#acc': 0.39}"
"{'acc#acc': 0.41, 'total#acc': 100.0, 'correct#acc': 41.0}"
]
},
"execution_count": 9,
@ -683,7 +647,7 @@
"\n",
"通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n",
"\n",
"&emsp; 通过`progress_bar`同时设定训练和评估进度条格式,训练结束后进度条会不显示(???)\n",
"&emsp; 通过`progress_bar`同时设定训练和评估进度条格式,在notebook中在进度条训练结束后会被丢弃\n",
"\n",
"&emsp; **通过`evaluate_every`设定评估频率**,可以为负数、正数或者函数:\n",
"\n",
@ -755,19 +719,6 @@
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
@ -788,9 +739,40 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"id": "c4e9c619",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'acc#acc': 0.46, 'total#acc': 100.0, 'correct#acc': 46.0}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluator.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db784d5b",
"metadata": {},
"outputs": [],
"source": []
}

View File

@ -153,7 +153,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"1630555358408 1630228349768\n",
"2492313174344 2491986424200\n",
"+-----+------------------------+------------------------+-----+\n",
"| idx | sentence | words | num |\n",
"+-----+------------------------+------------------------+-----+\n",
@ -198,7 +198,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"1630228349768 1630228349768\n",
"2491986424200 2491986424200\n",
"+-----+------------------------+------------------------+-----+\n",
"| idx | sentence | words | num |\n",
"+-----+------------------------+------------------------+-----+\n",
@ -680,9 +680,9 @@
{
"data": {
"text/plain": [
"{'sentence': <fastNLP.core.dataset.field.FieldArray at 0x17ba4b3a648>,\n",
" 'words': <fastNLP.core.dataset.field.FieldArray at 0x17ba4b3a6c8>,\n",
" 'num': <fastNLP.core.dataset.field.FieldArray at 0x17ba4b3a748>}"
"{'sentence': <fastNLP.core.dataset.field.FieldArray at 0x2444977fe88>,\n",
" 'words': <fastNLP.core.dataset.field.FieldArray at 0x2444977ff08>,\n",
" 'num': <fastNLP.core.dataset.field.FieldArray at 0x2444977ff88>}"
]
},
"execution_count": 15,

View File

@ -4,36 +4,28 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# T2. dataloader 和 tokenizer 的基本使用\n",
"# T2. databundle 和 tokenizer 的基本使用\n",
"\n",
"&emsp; 1 &ensp; fastNLP 中的 dataloader\n",
"&emsp; 1 &ensp; fastNLP 中 dataset 的延伸\n",
"\n",
"&emsp; &emsp; 1.1 &ensp; databundle 的结构与使用\n",
"\n",
"&emsp; &emsp; 1.2 &ensp; dataloader 的结构与使用\n",
"&emsp; &emsp; 1.1 &ensp; databundle 的概念与使用\n",
"\n",
"&emsp; 2 &ensp; fastNLP 中的 tokenizer\n",
" \n",
"&emsp; &emsp; 2.1 &ensp; 传统 GloVe 词嵌入的加载\n",
"&emsp; &emsp; 2.1 &ensp; PreTrainedTokenizer 的概念\n",
"\n",
"&emsp; &emsp; 2.2 &ensp; BertTokenizer 的基本使用\n",
" \n",
"&emsp; &emsp; 2.2 &ensp; PreTrainedTokenizer 的概念\n",
"\n",
"&emsp; &emsp; 2.3 &ensp; BertTokenizer 的基本使用\n",
"\n",
"&emsp; 3 &ensp; 实例NG20 数据集的完整加载过程\n",
" \n",
"&emsp; &emsp; 3.1 &ensp; \n",
"\n",
"&emsp; &emsp; 3.2 &ensp; "
"&emsp; &emsp; 2.3 &ensp; 补充GloVe 词嵌入的使用"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. fastNLP 中的 dataloader\n",
"## 1. fastNLP 中 dataset 的延伸\n",
"\n",
"### 1.1 databundle 的结构与使用\n",
"### 1.1 databundle 的概念与使用\n",
"\n",
"在`fastNLP 0.8`中,在常用的数据加载模块`DataLoader`和数据集`DataSet`模块之间,还存在\n",
"\n",
@ -43,13 +35,13 @@
"\n",
"&emsp; 分别存储在`datasets`和`vocabs`两个变量中,所以了解`databundle`数据包之前\n",
"\n",
"&emsp; 需要首先**复习`dataset`数据集和`vocabulary`词汇表****下面的一串代码****你知道其大概含义吗?**\n",
"需要首先**复习`dataset`数据集和`vocabulary`词汇表****下面的一串代码****你知道其大概含义吗?**\n",
"\n",
"必要提示:`NG20`,全称[`News Group 20`](http://qwone.com/~jason/20Newsgroups/)是一个新闻文本分类数据集包含20个大类以及若干小类\n",
"<!-- 必要提示:`NG20`,全称[`News Group 20`](http://qwone.com/~jason/20Newsgroups/)是一个新闻文本分类数据集包含20个类别\n",
"\n",
"&emsp; 数据集包含训练集`'ng20_train.csv'`和测试集`'ng20_test.csv'`两部分,每条数据\n",
"\n",
"&emsp; 包括`'label'`标签和`'text'`文本两个条目,通过`sample(frac=1)[:10]`随机采样并读取前十条"
"&emsp; 包括`'label'`标签和`'text'`文本两个条目,通过`sample(frac=1)[:6]`随机采样并读取前6条 -->"
]
},
{
@ -78,7 +70,171 @@
"version_minor": 0
},
"text/plain": [
"Processing: 0%| | 0/10 [00:00<?, ?it/s]"
"Processing: 0%| | 0/6 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------------------------------------------+----------+\n",
"| text | label |\n",
"+------------------------------------------+----------+\n",
"| ['this', 'quiet', ',', 'introspective... | positive |\n",
"| ['a', 'comedy-drama', 'of', 'nearly',... | positive |\n",
"| ['a', 'positively', 'thrilling', 'com... | neutral |\n",
"| ['a', 'series', 'of', 'escapades', 'd... | negative |\n",
"+------------------------------------------+----------+\n",
"+------------------------------------------+----------+\n",
"| text | label |\n",
"+------------------------------------------+----------+\n",
"| ['even', 'fans', 'of', 'ismail', 'mer... | negative |\n",
"| ['the', 'importance', 'of', 'being', ... | neutral |\n",
"+------------------------------------------+----------+\n",
"{'<pad>': 0, '<unk>': 1, 'positive': 2, 'neutral': 3, 'negative': 4}\n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"from fastNLP import DataSet\n",
"from fastNLP import Vocabulary\n",
"from fastNLP.io import DataBundle\n",
"\n",
"datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv', sep='\\t'))\n",
"datasets.rename_field('Sentence', 'text')\n",
"datasets.rename_field('Sentiment', 'label')\n",
"datasets.apply_more(lambda ins:{'label': ins['label'].lower(), \n",
" 'text': ins['text'].lower().split()},\n",
" progress_bar='tqdm')\n",
"datasets.delete_field('SentenceId')\n",
"train_ds, test_ds = datasets.split(ratio=0.7)\n",
"datasets = {'train': train_ds, 'test': test_ds}\n",
"print(datasets['train'])\n",
"print(datasets['test'])\n",
"\n",
"vocabs = {}\n",
"vocabs['label'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='label')\n",
"vocabs['text'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='text')\n",
"print(vocabs['label'].word2idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<!-- 上述代码的含义是:随机读取`NG20`数据集中的各十条训练数据和测试数据,将标签都设为小写,对文本进行分词\n",
" -->\n",
"上述代码的含义是:从`test4dataset`的 6 条数据中,划分 4 条训练集(`int(6*0.7) = 4`2 条测试集\n",
"\n",
"&emsp; &emsp; 修改相关字段名称,删除序号字段,同时将标签都设为小写,对文本进行分词\n",
"\n",
"&emsp; 接着通过`concat`方法拼接测试集训练集,注意设置`inplace=False`,生成临时的新数据集\n",
"\n",
"&emsp; 使用`from_dataset`方法从拼接的数据集中抽取词汇表,为将数据集中的单词替换为序号做准备\n",
"\n",
"由此就可以得到**数据集字典`datasets`****对应训练集、测试集**)和**词汇表字典`vocabs`****对应数据集各字段**\n",
"\n",
"&emsp; 然后就可以初始化`databundle`了,通过`print`可以观察其大致结构,效果如下"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In total 2 datasets:\n",
"\ttrain has 4 instances.\n",
"\ttest has 2 instances.\n",
"In total 2 vocabs:\n",
"\tlabel has 5 entries.\n",
"\ttext has 96 entries.\n",
"\n",
"['train', 'test']\n",
"['label', 'text']\n"
]
}
],
"source": [
"data_bundle = DataBundle(datasets=datasets, vocabs=vocabs)\n",
"print(data_bundle)\n",
"print(data_bundle.get_dataset_names())\n",
"print(data_bundle.get_vocab_names())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"此外,也可以通过`data_bundle`的`num_dataset`和`num_vocab`返回数据表和词汇表个数\n",
"\n",
"&emsp; 通过`data_bundle`的`iter_datasets`和`iter_vocabs`遍历数据表和词汇表"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In total 2 datasets:\n",
"\ttrain has 4 instances.\n",
"\ttest has 2 instances.\n",
"In total 2 datasets:\n",
"\tlabel has 5 entries.\n",
"\ttext has 96 entries.\n"
]
}
],
"source": [
"print(\"In total %d datasets:\" % data_bundle.num_dataset)\n",
"for name, dataset in data_bundle.iter_datasets():\n",
" print(\"\\t%s has %d instances.\" % (name, len(dataset)))\n",
"print(\"In total %d datasets:\" % data_bundle.num_dataset)\n",
"for name, vocab in data_bundle.iter_vocabs():\n",
" print(\"\\t%s has %d entries.\" % (name, len(vocab)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"在数据包`databundle`中,也有和数据集`dataset`类似的四个`apply`函数,即\n",
"\n",
"&emsp; `apply`函数、`apply_field`函数、`apply_field_more`函数和`apply_more`函数\n",
"\n",
"&emsp; 负责对数据集进行预处理,如下所示是`apply_more`函数的示例,其他函数类似\n",
"\n",
"此外,通过`get_dataset`函数,可以通过数据表名`name`称找到对应数据表\n",
"\n",
"&emsp; 通过`get_vocab`函数,可以通过词汇表名`field_name`称找到对应词汇表"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -92,7 +248,7 @@
"version_minor": 0
},
"text/plain": [
"Processing: 0%| | 0/10 [00:00<?, ?it/s]"
"Processing: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
@ -102,93 +258,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
"+-------+------------------------------------------+\n",
"| label | text |\n",
"+-------+------------------------------------------+\n",
"| talk | ['mwilson', 'ncratl', 'atlantaga', 'n... |\n",
"| talk | ['ch981', 'cleveland', 'freenet', 'ed... |\n",
"| rec | ['mbeaving', 'bnr', 'ca', '\\\\(', 'bea... |\n",
"| soc | ['jayne', 'mmalt', 'guild', 'org', '\\... |\n",
"| talk | ['jrutledg', 'cs', 'ulowell', 'edu', ... |\n",
"| talk | ['cramer', 'optilink', 'com', '\\\\(', ... |\n",
"| comp | ['triton', 'unm', 'edu', '\\\\(', 'larr... |\n",
"| rec | ['ingres', 'com', '\\\\(', 'bruce', '\\\\... |\n",
"| comp | ['ldo', 'waikato', 'ac', 'nz', '\\\\(',... |\n",
"| misc | ['rebecca', 'rpi', 'edu', '\\\\(', 'ezr... |\n",
"+-------+------------------------------------------+\n",
"{'<pad>': 0, '<unk>': 1, 'rec': 2, 'talk': 3, 'comp': 4, 'soc': 5, 'misc': 6, 'sci': 7}\n"
"+------------------------------+----------+-----+\n",
"| text | label | len |\n",
"+------------------------------+----------+-----+\n",
"| ['this', 'quiet', ',', 'i... | positive | 11 |\n",
"| ['a', 'comedy-drama', 'of... | positive | 19 |\n",
"| ['a', 'positively', 'thri... | neutral | 26 |\n",
"| ['a', 'series', 'of', 'es... | negative | 37 |\n",
"+------------------------------+----------+-----+\n"
]
}
],
"source": [
"import pandas as pd\n",
"\n",
"from fastNLP import DataSet\n",
"from fastNLP import Vocabulary\n",
"from fastNLP.io import DataBundle\n",
"\n",
"datasets = {}\n",
"datasets['train'] = DataSet.from_pandas(pd.read_csv('./data/ng20_train.csv').sample(frac=1)[:10])\n",
"datasets['train'].apply_more(lambda ins:{'label': ins['label'].lower().split('.')[0], \n",
" 'text': ins['text'].lower().split()},\n",
" progress_bar='tqdm')\n",
"datasets['test'] = DataSet.from_pandas(pd.read_csv('./data/ng20_test.csv').sample(frac=1)[:10])\n",
"datasets['test'].apply_more(lambda ins:{'label': ins['label'].lower().split('.')[0], \n",
" 'text': ins['text'].lower().split()},\n",
" progress_bar='tqdm')\n",
"print(datasets['train'])\n",
"\n",
"vocabs = {}\n",
"vocabs['label'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='label')\n",
"vocabs['text'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='text')\n",
"print(vocabs['label'].word2idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"数据集比如分开的训练集、验证集和测试集以及各个field对应的vocabulary。\n",
" 该对象一般由fastNLP中各种Loader的load函数生成可以通过以下的方法获取里面的内容"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In total 2 datasets:\n",
"\ttrain has 10 instances.\n",
"\ttest has 10 instances.\n",
"In total 2 vocabs:\n",
"\tlabel has 8 entries.\n",
"\ttext has 1687 entries.\n",
"\n"
]
}
],
"source": [
"data_bundle = DataBundle(datasets=datasets, vocabs=vocabs)\n",
"print(data_bundle)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.2 dataloader 的结构与使用"
"data_bundle.apply_more(lambda ins:{'len': len(ins['text'])}, progress_bar='tqdm')\n",
"print(data_bundle.get_dataset('train'))"
]
},
{
@ -197,14 +280,9 @@
"source": [
"## 2. fastNLP 中的 tokenizer\n",
"\n",
"### 2.1 传统 GloVe 词嵌入的加载"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.2 PreTrainTokenizer 的提出\n",
"### 2.1 PreTrainTokenizer 的提出\n",
"\n",
"为什么要放弃传统的GloVe词嵌入\n",
"\n",
"在`fastNLP 0.8`中,**使用`PreTrainedTokenizer`模块来为数据集中的词语进行词向量的标注**\n",
"\n",
@ -223,7 +301,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.3 BertTokenizer 的基本使用\n",
"### 2.2 BertTokenizer 的基本使用\n",
"\n",
"在`fastNLP 0.8`中,以`PreTrainedTokenizer`为基类,泛化出多个子类,实现基于`BERT`等模型的标注\n",
"\n",
@ -231,7 +309,7 @@
"\n",
"**`BertTokenizer`的初始化包括 导入模块和导入数据 两步**,先通过从`fastNLP.transformers.torch`中\n",
"\n",
"&emsp; 导入`BertTokenizer`模块,再通过`from_pretrained`方法指定`tokenizer`参数类型下载\n",
"&emsp; 导入`BertTokenizer`模块,再**通过`from_pretrained`方法指定`tokenizer`参数类型下载**\n",
"\n",
"&emsp; 其中,**`'bert-base-uncased'`指定`tokenizer`使用的预训练`BERT`类型**:单词不区分大小写\n",
"\n",
@ -242,7 +320,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {
"scrolled": false
},
@ -254,57 +332,390 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"dir(tokenizer)"
"通过变量`vocab_size`和`vocab_files_names`可以查看`BertTokenizer`的词汇表的大小和对应文件\n",
"\n",
"&emsp; 通过变量`vocab`可以访问`BertTokenizer`预训练的词汇表(由于内容过大就不演示了"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"30522 {'vocab_file': 'vocab.txt'}\n"
]
}
],
"source": [
"print(tokenizer.vocab_size, tokenizer.vocab_files_names)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. 实例NG20 数据集的完整加载过程\n",
"通过变量`all_special_tokens`或通过变量`special_tokens_map`可以**查看`BertTokenizer`内置的特殊词素**\n",
"\n",
"### 3.1 使用 BertTokenizer 处理数据集\n",
"&emsp; 包括**未知符`'[UNK]'`**, **断句符`'[SEP]'`**, **补零符`'[PAD]'`**, **分类符`'[CLS]'`**, **掩模`'[MASK]'`**\n",
"\n",
"在`fastNLP 0.8`中,**`Trainer`模块和`Evaluator`模块分别表示“训练器”和“评测器”**\n",
"通过变量`all_special_ids`可以**查看`BertTokenizer`内置的特殊词素对应的词汇表编号**,相同功能\n",
"\n",
"&emsp; 对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n",
"\n",
"在`fastNLP 0.8`中,需要注意,在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n",
"\n",
"&emsp; 非常关键的问题在于**如何正确设置二者的`driver`**。这就引入了另一个问题:什么是 `driver`"
"&emsp; 也可以直接通过查看`pad_token`,值为`'[UNK]'`,和`pad_token_id`,值为`0`,等变量来实现"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pad_token [PAD] 0\n",
"unk_token [UNK] 100\n",
"cls_token [CLS] 101\n",
"sep_token [SEP] 102\n",
"msk_token [MASK] 103\n",
"all_tokens ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] [100, 102, 0, 101, 103]\n",
"{'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}\n"
]
}
],
"source": [
"import pandas as pd\n",
"print('pad_token', tokenizer.pad_token, tokenizer.pad_token_id) \n",
"print('unk_token', tokenizer.unk_token, tokenizer.unk_token_id) \n",
"print('cls_token', tokenizer.cls_token, tokenizer.cls_token_id) \n",
"print('sep_token', tokenizer.sep_token, tokenizer.sep_token_id)\n",
"print('msk_token', tokenizer.mask_token, tokenizer.mask_token_id)\n",
"print('all_tokens', tokenizer.all_special_tokens, tokenizer.all_special_ids)\n",
"print(tokenizer.special_tokens_map)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"此外,还可以添加其他特殊字符,例如起始符`[BOS]`、终止符`[EOS]`,添加后词汇表编号也会相应改变\n",
"\n",
"from fastNLP import DataSet\n",
"from fastNLP import Vocabulary\n",
"\n",
"dataset = DataSet.from_pandas(pd.read_csv('./data/ng20_test.csv'))"
"&emsp; *但是如何添加这两个之外的字符,并且如何将这两个的编号设置为`[UNK]`之外的编号???*"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"bos_token [BOS] 100\n",
"eos_token [EOS] 100\n",
"all_tokens ['[BOS]', '[EOS]', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] [100, 100, 100, 102, 0, 101, 103]\n",
"{'bos_token': '[BOS]', 'eos_token': '[EOS]', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}\n"
]
}
],
"source": [
"tokenizer.bos_token = '[BOS]'\n",
"tokenizer.eos_token = '[EOS]'\n",
"# tokenizer.bos_token_id = 104\n",
"# tokenizer.eos_token_id = 105\n",
"print('bos_token', tokenizer.bos_token, tokenizer.bos_token_id)\n",
"print('eos_token', tokenizer.eos_token, tokenizer.eos_token_id)\n",
"print('all_tokens', tokenizer.all_special_tokens, tokenizer.all_special_ids)\n",
"print(tokenizer.special_tokens_map)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"在`BertTokenizer`中,**使用`tokenize`函数和`convert_tokens_to_string`函数可以实现文本和词素列表的互转**\n",
"\n",
"&emsp; 此外,**使用`convert_tokens_to_ids`函数和`convert_ids_to_tokens`函数则可以实现词素和词素编号的互转**\n",
"\n",
"&emsp; 上述四个函数的使用效果如下所示,此处可以明显看出,`tokenizer`分词和传统分词的不同效果,例如`'##cap'`"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 1012]\n",
"['a', 'series', 'of', 'es', '##cap', '##ades', 'demonstrating', 'the', 'ada', '##ge', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gan', '##der', ',', 'some', 'of', 'which', 'occasionally', 'am', '##uses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.']\n",
"a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .\n"
]
}
],
"source": [
"text = \"a series of escapades demonstrating the adage that what is \" \\\n",
" \"good for the goose is also good for the gander , some of which \" \\\n",
" \"occasionally amuses but none of which amounts to much of a story .\" \n",
"tks = ['a', 'series', 'of', 'es', '##cap', '##ades', 'demonstrating', 'the', \n",
" 'ada', '##ge', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', \n",
" 'is', 'also', 'good', 'for', 'the', 'gan', '##der', ',', 'some', 'of', \n",
" 'which', 'occasionally', 'am', '##uses', 'but', 'none', 'of', 'which', \n",
" 'amounts', 'to', 'much', 'of', 'a', 'story', '.']\n",
"ids = [ 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, \n",
" 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204,\n",
" 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572,\n",
" 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037,\n",
" 2466, 1012]\n",
"\n",
"tokens = tokenizer.tokenize(text)\n",
"print(tokenizer.convert_tokens_to_ids(tokens))\n",
"\n",
"ids = tokenizer.convert_tokens_to_ids(tokens)\n",
"print(tokenizer.convert_ids_to_tokens(ids))\n",
"\n",
"print(tokenizer.convert_tokens_to_string(tokens))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"在`BertTokenizer`中,还有另外两个函数可以实现分词标注,分别是 **`encode`和`decode`函数****可以直接实现**\n",
"\n",
"&emsp; **文本字符串和词素编号列表的互转**,但是编码过程中会按照`BERT`的规则,**在句子首末加入`[CLS]`和`[SEP]`**"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 1012, 102]\n",
"[CLS] a series of escapades demonstrating the adage that what is good for the goose is also good for the gander, some of which occasionally amuses but none of which amounts to much of a story. [SEP]\n"
]
}
],
"source": [
"enc = tokenizer.encode(text)\n",
"print(tokenizer.encode(text))\n",
"dec = tokenizer.decode(enc)\n",
"print(tokenizer.decode(enc))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"在`encode`函数之上,还有`encode_plus`函数,这也是在数据预处理中,`BertTokenizer`模块最常用到的函数\n",
"\n",
"&emsp; **`encode`函数的参数****`encode_plus`函数都有****`encode`函数词素编号列表****`encode_plus`函数返回字典**\n",
"\n",
"在`encode_plus`函数的返回值中,字段`input_ids`表示词素编号,其余两个字段后文有详细解释\n",
"\n",
"&emsp; **字段`token_type_ids`详见`text_pairs`的示例****字段`attention_mask`详见`batch_text`的示例**\n",
"\n",
"在`encode_plus`函数的参数中,参数`add_special_tokens`表示是否按照`BERT`的规则,加入相关特殊字符\n",
"\n",
"&emsp; 参数`max_length`表示句子截取最大长度(算特殊字符),在参数`truncation=True`时会自动截取\n",
"\n",
"&emsp; 参数`return_attention_mask`约定返回的字典中是否包括`attention_mask`字段,以上案例如下"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'input_ids': [101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
]
}
],
"source": [
"text = \"a series of escapades demonstrating the adage that what is good for the goose is also good for \"\\\n",
" \"the gander , some of which occasionally amuses but none of which amounts to much of a story .\" \n",
"\n",
"encoded = tokenizer.encode_plus(text=text, add_special_tokens=True, max_length=32, \n",
" truncation=True, return_attention_mask=True)\n",
"print(encoded)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"在`encode_plus`函数之上,还有`batch_encode_plus`函数(类似地,在`decode`之上,还有`batch_decode`\n",
"\n",
"&emsp; 两者参数类似,**`batch_encode_plus`函数针对批量文本`batch_text`****或者批量句对`text_pairs`**\n",
"\n",
"在针对批量文本`batch_text`的示例中,注意`batch_encode_plus`函数返回字典中的`attention_mask`字段\n",
"\n",
"&emsp; 可以发现,**`attention_mask`字段通过`01`标注出词素序列中该位置是否为补零**,可以用做自注意力的掩模"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'input_ids': [[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 102, 0, 0], [101, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 102], [101, 2070, 1997, 2029, 5681, 2572, 25581, 102, 0, 0, 0, 0, 0, 0, 0], [101, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 102, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]}\n"
]
}
],
"source": [
"batch_text = [\"a series of escapades demonstrating the adage that\",\n",
" \"what is good for the goose is also good for the gander\",\n",
" \"some of which occasionally amuses\",\n",
" \"but none of which amounts to much of a story\" ]\n",
"\n",
"encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=batch_text, padding=True,\n",
" add_special_tokens=True, max_length=16, truncation=True, \n",
" return_attention_mask=True)\n",
"print(encoded)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"而在针对批量句对`text_pairs`的示例中,注意`batch_encode_plus`函数返回字典中的`attention_mask`字段\n",
"\n",
"&emsp; 可以发现,**`token_type_ids`字段通过`01`标注出词素序列中该位置为句对中的第几句**,句对用`[SEP]`分割"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'input_ids': [[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 102, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 102], [101, 2070, 1997, 2029, 5681, 2572, 25581, 102, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 102, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]]}\n"
]
}
],
"source": [
"text_pairs = [(\"a series of escapades demonstrating the adage that\",\n",
" \"what is good for the goose is also good for the gander\"),\n",
" (\"some of which occasionally amuses\",\n",
" \"but none of which amounts to much of a story\")]\n",
"\n",
"encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=text_pairs, padding=True,\n",
" add_special_tokens=True, max_length=32, truncation=True, \n",
" return_attention_mask=True)\n",
"print(encoded)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"回到`encode_plus`上,在接下来的示例中,**使用内置的`functools.partial`模块构造`encode`函数**\n",
"\n",
"&emsp; 接着**使用该函数对`databundle`进行数据预处理**,由于`tokenizer.encode_plus`返回的是一个字典\n",
"\n",
"&emsp; 读入的是一个字段,所以此处使用`apply_field_more`方法,得到结果自动并入`databundle`中如下"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"functools.partial(<bound method PreTrainedTokenizerBase.encode_plus of PreTrainedTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'bos_token': '[BOS]', 'eos_token': '[EOS]', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})>, max_length=32, truncation=True, return_attention_mask=True)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+------------------+----------+-----+------------------+--------------------+--------------------+\n",
"| text | label | len | input_ids | token_type_ids | attention_mask |\n",
"+------------------+----------+-----+------------------+--------------------+--------------------+\n",
"| ['this', 'qui... | positive | 11 | [101, 2023, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
"| ['a', 'comedy... | positive | 19 | [101, 1037, 1... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
"| ['a', 'positi... | neutral | 26 | [101, 1037, 1... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
"| ['a', 'series... | negative | 37 | [101, 1037, 2... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
"+------------------+----------+-----+------------------+--------------------+--------------------+\n"
]
}
],
"source": [
"from functools import partial\n",
"\n",
"encode = partial(tokenizer.encode_plus, max_length=100, truncation=True,\n",
"encode = partial(tokenizer.encode_plus, max_length=32, truncation=True,\n",
" return_attention_mask=True)\n",
"# 会新增 input_ids 、 attention_mask 和 token_type_ids 这三个 field\n",
"dataset.apply_field_more(encode, field_name='text')"
"print(encode)\n",
"\n",
"data_bundle.apply_field_more(encode, field_name='text', progress_bar='tqdm')\n",
"print(data_bundle.datasets['train'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### 2.3 补充GloVe 词嵌入的使用\n",
"\n",
"如何使用传统的GloVe词嵌入"
]
},
{
@ -312,16 +723,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"target_vocab = Vocabulary(padding=None, unknown=None)\n",
"\n",
"target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n",
"target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n",
" new_field_name='labels')\n",
"# 需要将 input_ids 的 pad 值设置为 tokenizer 的 pad 值\n",
"dataset.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
"dataset.set_ignore('label', 'text') # 因为 label 是原始的不需要的 str ,所以我们可以忽略它,让它不要在 batch 的输出中出现"
]
"source": []
}
],
"metadata": {
@ -341,6 +743,15 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,

View File

@ -5,7 +5,7 @@
"id": "213d538c",
"metadata": {},
"source": [
"# T3. \n",
"# T3. dataloader 的内部结构和基本使用\n",
"\n",
"&emsp; 1 &ensp; \n",
" \n",
@ -26,6 +26,206 @@
"&emsp; &emsp; 3.2 &ensp; "
]
},
{
"cell_type": "markdown",
"id": "d74d0523",
"metadata": {},
"source": [
"## 3. fastNLP 中的 dataloader\n",
"\n",
"### 3.1 collator 的概念与使用\n",
"\n",
"在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前,还存在其他的一些模块,负责例如对文本数据\n",
"\n",
"&emsp; 进行补零对齐,即 **核对器`collator`模块**,进行分词标注,即 **分词器`tokenizer`模块**\n",
"\n",
"&emsp; 本节将对`fastNLP`中的核对器`collator`等展开介绍,分词器`tokenizer`将在下一节中详细介绍\n",
"\n",
"在`fastNLP 0.8`中,**核对器`collator`模块负责文本序列的补零对齐**,通过"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aca72b49",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from fastNLP import Collator\n",
"\n",
"collator = Collator()\n",
"# collator.set_pad(field_name='text', pad_val='<pad>')"
]
},
{
"cell_type": "markdown",
"id": "51bf0878",
"metadata": {},
"source": [
"&emsp; "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3fd2486f",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "f9bbd9a7",
"metadata": {},
"source": [
"### 3.2 dataloader 的结构与使用"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "651baef6",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from fastNLP import prepare_torch_dataloader\n",
"\n",
"dl_bundle = prepare_torch_dataloader(data_bundle, train_batch_size=2)\n",
"\n",
"print(type(dl_bundle), type(dl_bundle['train']))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "726ba357",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"dataloader = prepare_torch_dataloader(datasets['train'], train_batch_size=2)\n",
"print(type(dataloader))\n",
"print(dir(dataloader))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0795b3e",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"dataloader.collate_fn"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0c3c58d",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"dataloader.batch_sampler"
]
},
{
"cell_type": "markdown",
"id": "7ed431cc",
"metadata": {},
"source": [
"### 3.3 实例NG20 的加载预处理\n",
"\n",
"在`fastNLP 0.8`中,**`Trainer`模块和`Evaluator`模块分别表示“训练器”和“评测器”**\n",
"\n",
"&emsp; 对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n",
"\n",
"在`fastNLP 0.8`中,需要注意,在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n",
"\n",
"&emsp; 非常关键的问题在于**如何正确设置二者的`driver`**。这就引入了另一个问题:什么是 `driver`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a89ef613",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"from fastNLP import DataSet\n",
"from fastNLP import Vocabulary\n",
"\n",
"dataset = DataSet.from_pandas(pd.read_csv('./data/ng20_test.csv'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1624b0fa",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"encode = partial(tokenizer.encode_plus, max_length=100, truncation=True,\n",
" return_attention_mask=True)\n",
"# 会新增 input_ids 、 attention_mask 和 token_type_ids 这三个 field\n",
"dataset.apply_field_more(encode, field_name='text')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0991a8ee",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"target_vocab = Vocabulary(padding=None, unknown=None)\n",
"\n",
"target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n",
"target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n",
" new_field_name='labels')\n",
"# 需要将 input_ids 的 pad 值设置为 tokenizer 的 pad 值\n",
"dataset.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
"dataset.set_ignore('label', 'text') # 因为 label 是原始的不需要的 str ,所以我们可以忽略它,让它不要在 batch 的输出中出现"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -52,6 +252,15 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,