mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 21:28:01 +08:00
261 lines
11 KiB
Plaintext
261 lines
11 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# BertEmbedding的各种用法\n",
|
||
"Bert自从在 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 中被提出后,因其性能卓越受到了极大的关注,在这里我们展示一下在fastNLP中如何使用Bert进行各类任务。其中中文Bert我们使用的模型的权重来自于 中文Bert预训练 。\n",
|
||
"\n",
|
||
"为了方便大家的使用,fastNLP提供了预训练的Embedding权重及数据集的自动下载,支持自动下载的Embedding和数据集见 数据集 。或您可从 使用Embedding模块将文本转成向量 与 使用Loader和Pipe加载并处理数据集 了解更多相关信息\n",
|
||
"\n",
|
||
"\n",
|
||
"下面我们将介绍通过使用Bert来进行文本分类, 中文命名实体识别, 文本匹配, 中文问答。\n",
|
||
"\n",
|
||
"## 1. 使用Bert进行文本分类\n",
|
||
"\n",
|
||
"文本分类是指给定一段文字,判定其所属的类别。例如下面的文本情感分类\n",
|
||
"\n",
|
||
" *1, 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!*\n",
|
||
"\n",
|
||
"这里我们使用fastNLP提供自动下载的微博分类进行测试"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from fastNLP.io import WeiboSenti100kPipe\n",
|
||
"from fastNLP.embeddings import BertEmbedding\n",
|
||
"from fastNLP.models import BertForSequenceClassification\n",
|
||
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
|
||
"import torch\n",
|
||
"\n",
|
||
"data_bundle =WeiboSenti100kPipe().process_from_file()\n",
|
||
"data_bundle.rename_field('chars', 'words')\n",
|
||
"\n",
|
||
"# 载入BertEmbedding\n",
|
||
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)\n",
|
||
"\n",
|
||
"# 载入模型\n",
|
||
"model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))\n",
|
||
"\n",
|
||
"# 训练模型\n",
|
||
"device = 0 if torch.cuda.is_available() else 'cpu' \n",
|
||
"trainer = Trainer(data_bundle.get_dataset('train'), model,\n",
|
||
" optimizer=Adam(model_params=model.parameters(), lr=2e-5),\n",
|
||
" loss=CrossEntropyLoss(), device=device,\n",
|
||
" batch_size=8, dev_data=data_bundle.get_dataset('dev'),\n",
|
||
" metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n",
|
||
"trainer.train()\n",
|
||
"\n",
|
||
"# 测试结果\n",
|
||
"from fastNLP import Tester\n",
|
||
"\n",
|
||
"tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())\n",
|
||
"tester.test()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2. 使用Bert进行命名实体识别\n",
|
||
"\n",
|
||
"命名实体识别是给定一句话,标记出其中的实体。一般序列标注的任务都使用conll格式,conll格式是至一行中通过制表符分隔不同的内容,使用空行分隔 两句话,例如下面的例子\n",
|
||
"\n",
|
||
"```\n",
|
||
" 中 B-ORG\n",
|
||
" 共 I-ORG\n",
|
||
" 中 I-ORG\n",
|
||
" 央 I-ORG\n",
|
||
" 致 O\n",
|
||
" 中 B-ORG\n",
|
||
" 国 I-ORG\n",
|
||
" 致 I-ORG\n",
|
||
" 公 I-ORG\n",
|
||
" 党 I-ORG\n",
|
||
" 十 I-ORG\n",
|
||
" 一 I-ORG\n",
|
||
" 大 I-ORG\n",
|
||
" 的 O\n",
|
||
" 贺 O\n",
|
||
" 词 O\n",
|
||
"```\n",
|
||
"\n",
|
||
"这部分内容请参考 快速实现序列标注模型\n",
|
||
"\n",
|
||
"## 3. 使用Bert进行文本匹配\n",
|
||
"\n",
|
||
"文本匹配任务是指给定两句话判断他们的关系。比如,给定两句话判断前一句是否和后一句具有因果关系或是否是矛盾关系;或者给定两句话判断两句话是否 具有相同的意思。这里我们使用"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from fastNLP.io import CNXNLIBertPipe\n",
|
||
"from fastNLP.embeddings import BertEmbedding\n",
|
||
"from fastNLP.models import BertForSentenceMatching\n",
|
||
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
|
||
"from fastNLP.core.optimizer import AdamW\n",
|
||
"from fastNLP.core.callback import WarmupCallback\n",
|
||
"from fastNLP import Tester\n",
|
||
"import torch\n",
|
||
"\n",
|
||
"data_bundle = CNXNLIBertPipe().process_from_file()\n",
|
||
"data_bundle.rename_field('chars', 'words')\n",
|
||
"print(data_bundle)\n",
|
||
"\n",
|
||
"# 载入BertEmbedding\n",
|
||
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)\n",
|
||
"\n",
|
||
"# 载入模型\n",
|
||
"model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))\n",
|
||
"\n",
|
||
"# 训练模型\n",
|
||
"callbacks = [WarmupCallback(warmup=0.1, schedule='linear'), ]\n",
|
||
"device = 0 if torch.cuda.is_available() else 'cpu' \n",
|
||
"trainer = Trainer(data_bundle.get_dataset('train'), model,\n",
|
||
" optimizer=AdamW(params=model.parameters(), lr=4e-5),\n",
|
||
" loss=CrossEntropyLoss(), device=device,\n",
|
||
" batch_size=8, dev_data=data_bundle.get_dataset('dev'),\n",
|
||
" metrics=AccuracyMetric(), n_epochs=5, print_every=1,\n",
|
||
" update_every=8, callbacks=callbacks)\n",
|
||
"trainer.train()\n",
|
||
"\n",
|
||
"tester = Tester(data_bundle.get_dataset('test'), model, batch_size=8, metrics=AccuracyMetric())\n",
|
||
"tester.test()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 4. 使用Bert进行中文问答\n",
|
||
"\n",
|
||
"问答任务是给定一段内容,以及一个问题,需要从这段内容中找到答案。 例如:\n",
|
||
"\n",
|
||
"```\n",
|
||
"\"context\": \"锣鼓经是大陆传统器乐及戏曲里面常用的打击乐记谱方法,以中文字的声音模拟敲击乐的声音,纪录打击乐的各种不同的演奏方法。常\n",
|
||
"用的节奏型称为「锣鼓点」。而锣鼓是戏曲节奏的支柱,除了加强演员身段动作的节奏感,也作为音乐的引子和尾声,提示音乐的板式和速度,以及\n",
|
||
"作为唱腔和念白的伴奏,令诗句的韵律更加抑扬顿锉,段落分明。锣鼓的运用有约定俗成的程式,依照角色行当的身份、性格、情绪以及环境,配合\n",
|
||
"相应的锣鼓点。锣鼓亦可以模仿大自然的音响效果,如雷电、波浪等等。戏曲锣鼓所运用的敲击乐器主要分为鼓、锣、钹和板四类型:鼓类包括有单\n",
|
||
"皮鼓(板鼓)、大鼓、大堂鼓(唐鼓)、小堂鼓、怀鼓、花盆鼓等;锣类有大锣、小锣(手锣)、钲锣、筛锣、马锣、镗锣、云锣;钹类有铙钹、大\n",
|
||
"钹、小钹、水钹、齐钹、镲钹、铰子、碰钟等;打拍子用的檀板、木鱼、梆子等。因为京剧的锣鼓通常由四位乐师负责,又称为四大件,领奏的师\n",
|
||
"傅称为:「鼓佬」,其职责有如西方乐队的指挥,负责控制速度以及利用各种手势提示乐师演奏不同的锣鼓点。粤剧吸收了部份京剧的锣鼓,但以木鱼\n",
|
||
"和沙的代替了京剧的板和鼓,作为打拍子的主要乐器。以下是京剧、昆剧和粤剧锣鼓中乐器对应的口诀用字:\",\n",
|
||
"\"question\": \"锣鼓经是什么?\",\n",
|
||
"\"answers\": [\n",
|
||
" {\n",
|
||
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
|
||
" \"answer_start\": 4\n",
|
||
" },\n",
|
||
" {\n",
|
||
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
|
||
" \"answer_start\": 4\n",
|
||
" },\n",
|
||
" {\n",
|
||
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
|
||
" \"answer_start\": 4\n",
|
||
" }\n",
|
||
"]\n",
|
||
"```"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"您可以通过以下的代码训练 (原文代码:[CMRC2018](https://github.com/ymcui/cmrc2018) )"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from fastNLP.embeddings import BertEmbedding\n",
|
||
"from fastNLP.models import BertForQuestionAnswering\n",
|
||
"from fastNLP.core.losses import CMRC2018Loss\n",
|
||
"from fastNLP.core.metrics import CMRC2018Metric\n",
|
||
"from fastNLP.io.pipe.qa import CMRC2018BertPipe\n",
|
||
"from fastNLP import Trainer, BucketSampler\n",
|
||
"from fastNLP import WarmupCallback, GradientClipCallback\n",
|
||
"from fastNLP.core.optimizer import AdamW\n",
|
||
"import torch\n",
|
||
"\n",
|
||
"data_bundle = CMRC2018BertPipe().process_from_file()\n",
|
||
"data_bundle.rename_field('chars', 'words')\n",
|
||
"\n",
|
||
"print(data_bundle)\n",
|
||
"\n",
|
||
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn', requires_grad=True, include_cls_sep=False, auto_truncate=True,\n",
|
||
" dropout=0.5, word_dropout=0.01)\n",
|
||
"model = BertForQuestionAnswering(embed)\n",
|
||
"loss = CMRC2018Loss()\n",
|
||
"metric = CMRC2018Metric()\n",
|
||
"\n",
|
||
"wm_callback = WarmupCallback(schedule='linear')\n",
|
||
"gc_callback = GradientClipCallback(clip_value=1, clip_type='norm')\n",
|
||
"callbacks = [wm_callback, gc_callback]\n",
|
||
"\n",
|
||
"optimizer = AdamW(model.parameters(), lr=5e-5)\n",
|
||
"\n",
|
||
"device = 0 if torch.cuda.is_available() else 'cpu' \n",
|
||
"trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,\n",
|
||
" sampler=BucketSampler(seq_len_field_name='context_len'),\n",
|
||
" dev_data=data_bundle.get_dataset('dev'), metrics=metric,\n",
|
||
" callbacks=callbacks, device=device, batch_size=6, num_workers=2, n_epochs=2, print_every=1,\n",
|
||
" test_use_tqdm=False, update_every=10)\n",
|
||
"trainer.train(load_best_model=False)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"训练结果(和论文中报道的基本一致):\n",
|
||
"\n",
|
||
"```\n",
|
||
" In Epoch:2/Step:1692, got best dev performance:\n",
|
||
" CMRC2018Metric: f1=85.61, em=66.08\n",
|
||
"```"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python Now",
|
||
"language": "python",
|
||
"name": "now"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.8.0"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|