Updated the bert_embedding documentation and code

This commit is contained in:
ChenXin 2020-02-28 11:16:09 +08:00
parent eb74379fd9
commit 27d449fc23
3 changed files with 288 additions and 494 deletions


@@ -15,6 +15,10 @@ Since Bert was proposed in `BERT: Pre-training of Deep Bidirectional Transformers for Languag
----------------------------------
Below we walk through using Bert for text classification, Chinese named entity recognition, text matching, and Chinese question answering.
.. note::
This tutorial requires a GPU and takes a considerable amount of time to run.
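Before turning to the tasks, the following minimal sketch illustrates what BertEmbedding itself does: given a fastNLP Vocabulary, it maps word indices to contextual BERT vectors. The toy vocabulary below is illustrative only, and the 768-dimensional output assumes a base-size model.

.. code-block:: python

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

# a toy character vocabulary (illustrative only)
vocab = Vocabulary()
vocab.add_word_lst("今 天 天 气 真 好".split())

embed = BertEmbedding(vocab, model_dir_or_name='cn-wwm', include_cls_sep=False)
words = torch.LongTensor([[vocab.to_index(c) for c in "今 天 天 气 真 好".split()]])
print(embed(words).size())  # expected: torch.Size([1, 6, 768]) for a base-size model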
1. Text classification with Bert
----------------------------------
Text classification means deciding which category a given piece of text belongs to, for example the sentiment classification shown below.
@@ -28,26 +32,25 @@ Since Bert was proposed in `BERT: Pre-training of Deep Bidirectional Transformers for Languag
.. code-block:: python
from fastNLP.io import WeiboSenti100kPipe
from fastNLP.embeddings import BertEmbedding
from fastNLP.models import BertForSequenceClassification
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
import torch
data_bundle = WeiboSenti100kPipe().process_from_file()
data_bundle.rename_field('chars', 'words')
# Load the BertEmbedding
from fastNLP.embeddings import BertEmbedding
embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)
# Load the model
from fastNLP.models import BertForSequenceClassification
model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))
# Train the model
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
device = 0 if torch.cuda.is_available() else 'cpu'
trainer = Trainer(data_bundle.get_dataset('train'), model,
optimizer=Adam(model_params=model.parameters(), lr=2e-5),
loss=CrossEntropyLoss(), device=0,
loss=CrossEntropyLoss(), device=device,
batch_size=8, dev_data=data_bundle.get_dataset('dev'),
metrics=AccuracyMetric(), n_epochs=2, print_every=1)
trainer.train()
@@ -92,7 +95,7 @@ Since Bert was proposed in `BERT: Pre-training of Deep Bidirectional Transformers for Languag
贺 O
词 O
For this part, see :doc:`Build a sequence labeling model quickly </tutorials/tutorial_9_seq_labeling>`
For this part, see :doc:`/tutorials/序列标注`; a minimal sketch is also given below.
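If you would rather stay within this tutorial, here is a minimal, untested sketch of fine-tuning BERT for Chinese NER in the same style as the other tasks. The MsraNERPipe loader, the BertForTokenClassification model, and the SpanFPreRecMetric are assumptions carried over from other fastNLP tutorials, not code verified in this document.

.. code-block:: python

from fastNLP.io import MsraNERPipe
from fastNLP.embeddings import BertEmbedding
from fastNLP.models import BertForTokenClassification
from fastNLP import Trainer, CrossEntropyLoss, SpanFPreRecMetric, Adam
import torch

# conll-style MSRA NER data; BertEmbedding expects a 'words' field
data_bundle = MsraNERPipe().process_from_file()
data_bundle.rename_field('chars', 'words')

# include_cls_sep=False keeps the BERT output aligned with the input tokens
embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=False)
model = BertForTokenClassification(embed, len(data_bundle.get_vocab('target')))

device = 0 if torch.cuda.is_available() else 'cpu'
trainer = Trainer(data_bundle.get_dataset('train'), model,
                  optimizer=Adam(model_params=model.parameters(), lr=2e-5),
                  loss=CrossEntropyLoss(), device=device,
                  batch_size=8, dev_data=data_bundle.get_dataset('dev'),
                  metrics=SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target')),
                  n_epochs=2, print_every=1)
trainer.train()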
3. Text matching with Bert
@@ -102,36 +105,36 @@ Since Bert was proposed in `BERT: Pre-training of Deep Bidirectional Transformers for Languag
.. code-block:: python
data_bundle = CNXNLIBertPipe().process_from_file(paths)
from fastNLP.io import CNXNLIBertPipe
from fastNLP.embeddings import BertEmbedding
from fastNLP.models import BertForSentenceMatching
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
from fastNLP.core.optimizer import AdamW
from fastNLP.core.callback import WarmupCallback
from fastNLP import Tester
import torch
data_bundle = CNXNLIBertPipe().process_from_file()
data_bundle.rename_field('chars', 'words')
print(data_bundle)
# Load the BertEmbedding
from fastNLP.embeddings import BertEmbedding
embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)
# Load the model
from fastNLP.models import BertForSentenceMatching
model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))
# Train the model
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
from fastNLP.core.optimizer import AdamW
from fastNLP.core.callback import WarmupCallback
callbacks = [WarmupCallback(warmup=0.1, schedule='linear'), ]
device = 0 if torch.cuda.is_available() else 'cpu'
trainer = Trainer(data_bundle.get_dataset('train'), model,
optimizer=AdamW(params=model.parameters(), lr=4e-5),
loss=CrossEntropyLoss(), device=0,
loss=CrossEntropyLoss(), device=device,
batch_size=8, dev_data=data_bundle.get_dataset('dev'),
metrics=AccuracyMetric(), n_epochs=5, print_every=1,
update_every=8, callbacks=callbacks)
trainer.train()
from fastNLP import Tester
tester = Tester(data_bundle.get_dataset('test'), model, batch_size=8, metrics=AccuracyMetric())
tester.test()
@@ -174,7 +177,7 @@ Since Bert was proposed in `BERT: Pre-training of Deep Bidirectional Transformers for Languag
}
]
You can train it with the following code `CMRC2018 <https://github.com/ymcui/cmrc2018>`_
You can train it with the following code (original code: `CMRC2018 <https://github.com/ymcui/cmrc2018>`_)
.. code-block:: python
@@ -186,7 +189,7 @@ Since Bert was proposed in `BERT: Pre-training of Deep Bidirectional Transformers for Languag
from fastNLP import Trainer, BucketSampler
from fastNLP import WarmupCallback, GradientClipCallback
from fastNLP.core.optimizer import AdamW
import torch
data_bundle = CMRC2018BertPipe().process_from_file()
data_bundle.rename_field('chars', 'words')
@@ -205,14 +208,15 @@ Since Bert was proposed in `BERT: Pre-training of Deep Bidirectional Transformers for Languag
optimizer = AdamW(model.parameters(), lr=5e-5)
device = 0 if torch.cuda.is_available() else 'cpu'
trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,
sampler=BucketSampler(seq_len_field_name='context_len'),
dev_data=data_bundle.get_dataset('dev'), metrics=metric,
callbacks=callbacks, device=0, batch_size=6, num_workers=2, n_epochs=2, print_every=1,
callbacks=callbacks, device=device, batch_size=6, num_workers=2, n_epochs=2, print_every=1,
test_use_tqdm=False, update_every=10)
trainer.train(load_best_model=False)
Training results (essentially consistent with those reported in the paper)::
In Epoch:2/Step:1692, got best dev performance:
CMRC2018Metric: f1=85.61, em=66.08
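After training, the model can be checked with a Tester just as in the earlier sections. This is a sketch, assuming the dev split serves as the held-out set (the pipe above was only shown producing train and dev):

.. code-block:: python

# continuing the session above
from fastNLP import Tester
from fastNLP.core.metrics import CMRC2018Metric

tester = Tester(data_bundle.get_dataset('dev'), model, batch_size=6, metrics=CMRC2018Metric())
tester.test()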


@@ -1,470 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BertEmbedding的各种用法\n",
"fastNLP的BertEmbedding以pytorch-transformer.BertModel的代码为基础是一个使用BERT对words进行编码的Embedding。\n",
"\n",
"使用BertEmbedding和fastNLP.models.bert里面模型可以搭建BERT应用到五种下游任务的模型。\n",
"\n",
"*预训练好的Embedding参数及数据集的介绍和自动下载功能见 [Embedding教程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html) 和 [数据处理教程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html)。*\n",
"\n",
"## 1. BERT for Squence Classification\n",
"在文本分类任务中我们采用SST数据集作为例子来介绍BertEmbedding的使用方法。"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"import torch\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"In total 3 datasets:\n",
"\ttest has 2210 instances.\n",
"\ttrain has 8544 instances.\n",
"\tdev has 1101 instances.\n",
"In total 2 vocabs:\n",
"\twords has 21701 entries.\n",
"\ttarget has 5 entries."
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 载入数据集\n",
"from fastNLP.io import SSTPipe\n",
"data_bundle = SSTPipe(subtree=False, train_subtree=False, lower=False, tokenizer='raw').process_from_file()\n",
"data_bundle"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt\n",
"Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.\n",
"Start to generate word pieces for word.\n",
"Found(Or segment into word pieces) 21701 words out of 21701.\n"
]
}
],
"source": [
"# 载入BertEmbedding\n",
"from fastNLP.embeddings import BertEmbedding\n",
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# 载入模型\n",
"from fastNLP.models import BertForSequenceClassification\n",
"model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 37]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2019-09-11-17-35-26\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=268), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 2.08 seconds!\n",
"Evaluation on dev at Epoch 1/2. Step:134/268: \n",
"AccuracyMetric: acc=0.459582\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 2.2 seconds!\n",
"Evaluation on dev at Epoch 2/2. Step:268/268: \n",
"AccuracyMetric: acc=0.468665\n",
"\n",
"\n",
"In Epoch:2/Step:268, got best dev performance:\n",
"AccuracyMetric: acc=0.468665\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccuracyMetric': {'acc': 0.468665}},\n",
" 'best_epoch': 2,\n",
" 'best_step': 268,\n",
" 'seconds': 114.5}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 训练模型\n",
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
"trainer = Trainer(data_bundle.get_dataset('train'), model, \n",
" optimizer=Adam(model_params=model.parameters(), lr=2e-5), \n",
" loss=CrossEntropyLoss(), device=[0],\n",
" batch_size=64, dev_data=data_bundle.get_dataset('dev'), \n",
" metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 4.52 seconds!\n",
"[tester] \n",
"AccuracyMetric: acc=0.504072\n"
]
},
{
"data": {
"text/plain": [
"{'AccuracyMetric': {'acc': 0.504072}}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 测试结果并删除模型\n",
"from fastNLP import Tester\n",
"tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())\n",
"tester.test()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"## 2. BERT for Sentence Matching\n",
"在Matching任务中我们采用RTE数据集作为例子来介绍BertEmbedding的使用方法。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"In total 3 datasets:\n",
"\ttest has 3000 instances.\n",
"\ttrain has 2490 instances.\n",
"\tdev has 277 instances.\n",
"In total 2 vocabs:\n",
"\twords has 41281 entries.\n",
"\ttarget has 2 entries."
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 载入数据集\n",
"from fastNLP.io import RTEBertPipe\n",
"data_bundle = RTEBertPipe(lower=False, tokenizer='raw').process_from_file()\n",
"data_bundle"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt\n",
"Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.\n",
"Start to generate word pieces for word.\n",
"Found(Or segment into word pieces) 41279 words out of 41281.\n"
]
}
],
"source": [
"# 载入BertEmbedding\n",
"from fastNLP.embeddings import BertEmbedding\n",
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# 载入模型\n",
"from fastNLP.models import BertForSentenceMatching\n",
"model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 45]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2019-09-11-17-37-36\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=312), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 1.72 seconds!\n",
"Evaluation on dev at Epoch 1/2. Step:156/312: \n",
"AccuracyMetric: acc=0.624549\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 1.74 seconds!\n",
"Evaluation on dev at Epoch 2/2. Step:312/312: \n",
"AccuracyMetric: acc=0.649819\n",
"\n",
"\n",
"In Epoch:2/Step:312, got best dev performance:\n",
"AccuracyMetric: acc=0.649819\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccuracyMetric': {'acc': 0.649819}},\n",
" 'best_epoch': 2,\n",
" 'best_step': 312,\n",
" 'seconds': 109.87}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 训练模型\n",
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
"trainer = Trainer(data_bundle.get_dataset('train'), model, \n",
" optimizer=Adam(model_params=model.parameters(), lr=2e-5), \n",
" loss=CrossEntropyLoss(), device=[0],\n",
" batch_size=16, dev_data=data_bundle.get_dataset('dev'), \n",
" metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -0,0 +1,260 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BertEmbedding的各种用法\n",
"Bert自从在 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 中被提出后因其性能卓越受到了极大的关注在这里我们展示一下在fastNLP中如何使用Bert进行各类任务。其中中文Bert我们使用的模型的权重来自于 中文Bert预训练 。\n",
"\n",
"为了方便大家的使用fastNLP提供了预训练的Embedding权重及数据集的自动下载支持自动下载的Embedding和数据集见 数据集 。或您可从 使用Embedding模块将文本转成向量 与 使用Loader和Pipe加载并处理数据集 了解更多相关信息\n",
"\n",
"\n",
"下面我们将介绍通过使用Bert来进行文本分类, 中文命名实体识别, 文本匹配, 中文问答。\n",
"\n",
"## 1. 使用Bert进行文本分类\n",
"\n",
"文本分类是指给定一段文字,判定其所属的类别。例如下面的文本情感分类\n",
"\n",
" *1, 商务大床房房间很大床有2M宽整体感觉经济实惠不错!*\n",
"\n",
"这里我们使用fastNLP提供自动下载的微博分类进行测试"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.io import WeiboSenti100kPipe\n",
"from fastNLP.embeddings import BertEmbedding\n",
"from fastNLP.models import BertForSequenceClassification\n",
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
"import torch\n",
"\n",
"data_bundle =WeiboSenti100kPipe().process_from_file()\n",
"data_bundle.rename_field('chars', 'words')\n",
"\n",
"# 载入BertEmbedding\n",
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)\n",
"\n",
"# 载入模型\n",
"model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))\n",
"\n",
"# 训练模型\n",
"device = 0 if torch.cuda.is_available() else 'cpu' \n",
"trainer = Trainer(data_bundle.get_dataset('train'), model,\n",
" optimizer=Adam(model_params=model.parameters(), lr=2e-5),\n",
" loss=CrossEntropyLoss(), device=device,\n",
" batch_size=8, dev_data=data_bundle.get_dataset('dev'),\n",
" metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n",
"trainer.train()\n",
"\n",
"# 测试结果\n",
"from fastNLP import Tester\n",
"\n",
"tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())\n",
"tester.test()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. 使用Bert进行命名实体识别\n",
"\n",
"命名实体识别是给定一句话标记出其中的实体。一般序列标注的任务都使用conll格式conll格式是至一行中通过制表符分隔不同的内容使用空行分隔 两句话,例如下面的例子\n",
"\n",
"```\n",
" 中 B-ORG\n",
" 共 I-ORG\n",
" 中 I-ORG\n",
" 央 I-ORG\n",
" 致 O\n",
" 中 B-ORG\n",
" 国 I-ORG\n",
" 致 I-ORG\n",
" 公 I-ORG\n",
" 党 I-ORG\n",
" 十 I-ORG\n",
" 一 I-ORG\n",
" 大 I-ORG\n",
" 的 O\n",
" 贺 O\n",
" 词 O\n",
"```\n",
"\n",
"这部分内容请参考 快速实现序列标注模型\n",
"\n",
"## 3. 使用Bert进行文本匹配\n",
"\n",
"文本匹配任务是指给定两句话判断他们的关系。比如,给定两句话判断前一句是否和后一句具有因果关系或是否是矛盾关系;或者给定两句话判断两句话是否 具有相同的意思。这里我们使用"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.io import CNXNLIBertPipe\n",
"from fastNLP.embeddings import BertEmbedding\n",
"from fastNLP.models import BertForSentenceMatching\n",
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
"from fastNLP.core.optimizer import AdamW\n",
"from fastNLP.core.callback import WarmupCallback\n",
"from fastNLP import Tester\n",
"import torch\n",
"\n",
"data_bundle = CNXNLIBertPipe().process_from_file()\n",
"data_bundle.rename_field('chars', 'words')\n",
"print(data_bundle)\n",
"\n",
"# 载入BertEmbedding\n",
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)\n",
"\n",
"# 载入模型\n",
"model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))\n",
"\n",
"# 训练模型\n",
"callbacks = [WarmupCallback(warmup=0.1, schedule='linear'), ]\n",
"device = 0 if torch.cuda.is_available() else 'cpu' \n",
"trainer = Trainer(data_bundle.get_dataset('train'), model,\n",
" optimizer=AdamW(params=model.parameters(), lr=4e-5),\n",
" loss=CrossEntropyLoss(), device=device,\n",
" batch_size=8, dev_data=data_bundle.get_dataset('dev'),\n",
" metrics=AccuracyMetric(), n_epochs=5, print_every=1,\n",
" update_every=8, callbacks=callbacks)\n",
"trainer.train()\n",
"\n",
"tester = Tester(data_bundle.get_dataset('test'), model, batch_size=8, metrics=AccuracyMetric())\n",
"tester.test()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. 使用Bert进行中文问答\n",
"\n",
"问答任务是给定一段内容,以及一个问题,需要从这段内容中找到答案。 例如:\n",
"\n",
"```\n",
"\"context\": \"锣鼓经是大陆传统器乐及戏曲里面常用的打击乐记谱方法,以中文字的声音模拟敲击乐的声音,纪录打击乐的各种不同的演奏方法。常\n",
"用的节奏型称为「锣鼓点」。而锣鼓是戏曲节奏的支柱,除了加强演员身段动作的节奏感,也作为音乐的引子和尾声,提示音乐的板式和速度,以及\n",
"作为唱腔和念白的伴奏,令诗句的韵律更加抑扬顿锉,段落分明。锣鼓的运用有约定俗成的程式,依照角色行当的身份、性格、情绪以及环境,配合\n",
"相应的锣鼓点。锣鼓亦可以模仿大自然的音响效果,如雷电、波浪等等。戏曲锣鼓所运用的敲击乐器主要分为鼓、锣、钹和板四类型:鼓类包括有单\n",
"皮鼓(板鼓)、大鼓、大堂鼓(唐鼓)、小堂鼓、怀鼓、花盆鼓等;锣类有大锣、小锣(手锣)、钲锣、筛锣、马锣、镗锣、云锣;钹类有铙钹、大\n",
"钹、小钹、水钹、齐钹、镲钹、铰子、碰钟等;打拍子用的檀板、木鱼、梆子等。因为京剧的锣鼓通常由四位乐师负责,又称为四大件,领奏的师\n",
"傅称为:「鼓佬」,其职责有如西方乐队的指挥,负责控制速度以及利用各种手势提示乐师演奏不同的锣鼓点。粤剧吸收了部份京剧的锣鼓,但以木鱼\n",
"和沙的代替了京剧的板和鼓,作为打拍子的主要乐器。以下是京剧、昆剧和粤剧锣鼓中乐器对应的口诀用字:\",\n",
"\"question\": \"锣鼓经是什么?\",\n",
"\"answers\": [\n",
" {\n",
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
" \"answer_start\": 4\n",
" },\n",
" {\n",
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
" \"answer_start\": 4\n",
" },\n",
" {\n",
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
" \"answer_start\": 4\n",
" }\n",
"]\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"您可以通过以下的代码训练 (原文代码:[CMRC2018](https://github.com/ymcui/cmrc2018) )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.embeddings import BertEmbedding\n",
"from fastNLP.models import BertForQuestionAnswering\n",
"from fastNLP.core.losses import CMRC2018Loss\n",
"from fastNLP.core.metrics import CMRC2018Metric\n",
"from fastNLP.io.pipe.qa import CMRC2018BertPipe\n",
"from fastNLP import Trainer, BucketSampler\n",
"from fastNLP import WarmupCallback, GradientClipCallback\n",
"from fastNLP.core.optimizer import AdamW\n",
"import torch\n",
"\n",
"data_bundle = CMRC2018BertPipe().process_from_file()\n",
"data_bundle.rename_field('chars', 'words')\n",
"\n",
"print(data_bundle)\n",
"\n",
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn', requires_grad=True, include_cls_sep=False, auto_truncate=True,\n",
" dropout=0.5, word_dropout=0.01)\n",
"model = BertForQuestionAnswering(embed)\n",
"loss = CMRC2018Loss()\n",
"metric = CMRC2018Metric()\n",
"\n",
"wm_callback = WarmupCallback(schedule='linear')\n",
"gc_callback = GradientClipCallback(clip_value=1, clip_type='norm')\n",
"callbacks = [wm_callback, gc_callback]\n",
"\n",
"optimizer = AdamW(model.parameters(), lr=5e-5)\n",
"\n",
"device = 0 if torch.cuda.is_available() else 'cpu' \n",
"trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,\n",
" sampler=BucketSampler(seq_len_field_name='context_len'),\n",
" dev_data=data_bundle.get_dataset('dev'), metrics=metric,\n",
" callbacks=callbacks, device=device, batch_size=6, num_workers=2, n_epochs=2, print_every=1,\n",
" test_use_tqdm=False, update_every=10)\n",
"trainer.train(load_best_model=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"训练结果(和论文中报道的基本一致):\n",
"\n",
"```\n",
" In Epoch:2/Step:1692, got best dev performance:\n",
" CMRC2018Metric: f1=85.61, em=66.08\n",
"```"
]
},
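{
"cell_type": "markdown",
"metadata": {},
"source": [
"After training, the model can be checked with a Tester just as in the earlier sections. This is a sketch, assuming the dev split serves as the held-out set (the pipe above was only shown producing train and dev):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A sketch: evaluate the QA model; the dev split is assumed to be the held-out set\n",
"from fastNLP import Tester\n",
"\n",
"tester = Tester(data_bundle.get_dataset('dev'), model, batch_size=6, metrics=CMRC2018Metric())\n",
"tester.test()"
]
}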
],
"metadata": {
"kernelspec": {
"display_name": "Python Now",
"language": "python",
"name": "now"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}