mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-29 10:48:40 +08:00
!11 update packaging related
Merge pull request !11 from yh_cc/auto-1850475-master-1657210864575
This commit is contained in:
commit
4f2fd20ade
101
.Jenkinsfile
101
.Jenkinsfile
@ -1,36 +1,96 @@
|
||||
pipeline {
|
||||
agent {
|
||||
docker {
|
||||
image 'ubuntu_tester'
|
||||
args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci'
|
||||
}
|
||||
agent any
|
||||
options {
|
||||
timeout(time:30, unit: 'MINUTES')
|
||||
}
|
||||
environment {
|
||||
TRAVIS = 1
|
||||
PJ_NAME = 'fastNLP'
|
||||
POST_URL = 'https://open.feishu.cn/open-apis/bot/v2/hook/14719364-818d-4f88-9057-7c9f0eaaf6ae'
|
||||
POST_URL = 'https://open.feishu.cn/open-apis/bot/v2/hook/2f7122e3-3459-43d2-a9e4-ddd77bfc4282'
|
||||
}
|
||||
stages {
|
||||
stage('Package Installation') {
|
||||
steps {
|
||||
sh 'python setup.py install'
|
||||
}
|
||||
}
|
||||
stage('Parallel Stages') {
|
||||
parallel {
|
||||
stage('Document Building') {
|
||||
stage('Test Other'){
|
||||
agent {
|
||||
docker {
|
||||
image 'fnlp:other'
|
||||
args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci'
|
||||
}
|
||||
}
|
||||
steps {
|
||||
sh 'cd docs && make prod'
|
||||
sh 'rm -rf /docs/${PJ_NAME}'
|
||||
sh 'mv docs/build/html /docs/${PJ_NAME}'
|
||||
sh 'pytest ./tests --durations=0 --html=other.html --self-contained-html -m "not (torch or paddle or paddledist or jittor or oneflow or deepspeed or oneflowdist or torchpaddle or torchjittor or torchoneflow)"'
|
||||
}
|
||||
post {
|
||||
always {
|
||||
sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv other.html ${html_path}'
|
||||
}
|
||||
}
|
||||
}
|
||||
stage('Package Testing') {
|
||||
stage('Test Torch-1.11') {
|
||||
agent {
|
||||
docker {
|
||||
image 'fnlp:torch-1.11'
|
||||
args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci --gpus all --shm-size 1G'
|
||||
}
|
||||
}
|
||||
steps {
|
||||
sh 'pip install fitlog'
|
||||
sh 'pytest ./tests --html=test_results.html --self-contained-html'
|
||||
sh 'pytest ./tests/ --durations=0 --html=torch-1.11.html --self-contained-html -m torch'
|
||||
}
|
||||
post {
|
||||
always {
|
||||
sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv torch-1.11.html ${html_path}'
|
||||
}
|
||||
}
|
||||
}
|
||||
stage('Test Torch-1.6') {
|
||||
agent {
|
||||
docker {
|
||||
image 'fnlp:torch-1.6'
|
||||
args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci --gpus all --shm-size 1G'
|
||||
}
|
||||
}
|
||||
steps {
|
||||
sh 'pytest ./tests/ --durations=0 --html=torch-1.6.html --self-contained-html -m torch'
|
||||
}
|
||||
post {
|
||||
always {
|
||||
sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv torch-1.6.html ${html_path}'
|
||||
}
|
||||
}
|
||||
}
|
||||
stage('Test Paddle') {
|
||||
agent {
|
||||
docker {
|
||||
image 'fnlp:paddle'
|
||||
args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci --gpus all --shm-size 1G'
|
||||
}
|
||||
}
|
||||
steps {
|
||||
sh 'pytest ./tests --durations=0 --html=paddle.html --self-contained-html -m paddle --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests --durations=0 --html=paddle_with_backend.html --self-contained-html -m paddle --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/drivers/paddle_driver/test_dist_utils.py --durations=0 --html=paddle_dist_utils.html --self-contained-html --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/drivers/paddle_driver/test_fleet.py --durations=0 --html=paddle_fleet.html --self-contained-html --co'
|
||||
sh 'FASTNLP_BACKEND=paddle pytest ./tests/core/controllers/test_trainer_paddle.py --durations=0 --html=paddle_trainer.html --self-contained-html --co'
|
||||
}
|
||||
post {
|
||||
always {
|
||||
sh 'html_path=/ci/${PJ_NAME}/report-${BUILD_NUMBER}-${GIT_BRANCH#*/}-${GIT_COMMIT} && mkdir -p ${html_path} && mv paddle*.html ${html_path}'
|
||||
}
|
||||
}
|
||||
}
|
||||
// stage('Test Jittor') {
|
||||
// agent {
|
||||
// docker {
|
||||
// image 'fnlp:jittor'
|
||||
// args '-u root:root -v ${JENKINS_HOME}/html/docs:/docs -v ${JENKINS_HOME}/html/_ci:/ci --gpus all --shm-size 1G'
|
||||
// }
|
||||
// }
|
||||
// steps {
|
||||
// // sh 'pip install fitlog'
|
||||
// // sh 'pytest ./tests --html=test_results.html --self-contained-html'
|
||||
// sh 'pytest ./tests --durations=0 --html=jittor.html --self-contained-html -m jittor --co'
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -40,8 +100,7 @@ pipeline {
|
||||
}
|
||||
success {
|
||||
sh 'post 0'
|
||||
sh 'post github'
|
||||
// sh 'post github'
|
||||
}
|
||||
}
|
||||
|
||||
}
|
18
.gitignore
vendored
18
.gitignore
vendored
@ -1,18 +0,0 @@
|
||||
.gitignore
|
||||
|
||||
.DS_Store
|
||||
.ipynb_checkpoints
|
||||
*.pyc
|
||||
__pycache__
|
||||
*.swp
|
||||
.vscode/
|
||||
.idea/**
|
||||
|
||||
caches
|
||||
|
||||
# fitlog
|
||||
.fitlog
|
||||
logs/
|
||||
.fitconfig
|
||||
|
||||
docs/build
|
30
.travis.yml
30
.travis.yml
@ -1,30 +0,0 @@
|
||||
language: python
|
||||
python:
|
||||
- "3.6"
|
||||
|
||||
env:
|
||||
- TRAVIS=1
|
||||
|
||||
# command to install dependencies
|
||||
install:
|
||||
- pip install --quiet -r requirements.txt
|
||||
- pip install --quiet fitlog
|
||||
- pip install pytest>=3.6
|
||||
- pip install pytest-cov
|
||||
# command to run tests
|
||||
script:
|
||||
# - python -m spacy download en
|
||||
- pytest --cov=fastNLP tests/
|
||||
|
||||
after_success:
|
||||
- bash <(curl -s https://codecov.io/bash)
|
||||
|
||||
notifications:
|
||||
webhooks:
|
||||
urls:
|
||||
- https://open.feishu.cn/officialapp/notify/55ba4b15d04608e875c122f11484a4e2fa807c42b9ca074509bea654d1b99ca6
|
||||
on_success: always # default: always
|
||||
on_failure: always # default: always
|
||||
on_start: never # default: never
|
||||
on_cancel: always # default: always
|
||||
on_error: always # default: always
|
@ -2,6 +2,4 @@ include requirements.txt
|
||||
include LICENSE
|
||||
include README.md
|
||||
prune tests/
|
||||
prune reproduction/
|
||||
prune fastNLP/api
|
||||
prune fastNLP/automl
|
||||
prune tutorials/
|
290
README.md
290
README.md
@ -1,110 +1,239 @@
|
||||
# fastNLP
|
||||
|
||||
[![Build Status](https://travis-ci.org/fastnlp/fastNLP.svg?branch=master)](https://travis-ci.org/fastnlp/fastNLP)
|
||||
[![codecov](https://codecov.io/gh/fastnlp/fastNLP/branch/master/graph/badge.svg)](https://codecov.io/gh/fastnlp/fastNLP)
|
||||
[![Pypi](https://img.shields.io/pypi/v/fastNLP.svg)](https://pypi.org/project/fastNLP)
|
||||
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
|
||||
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)
|
||||
|
||||
fastNLP是一款面向自然语言处理(NLP)的轻量级框架,目标是快速实现NLP任务以及构建复杂模型。
|
||||
[//]: # ([![Build Status](https://travis-ci.org/fastnlp/fastNLP.svg?branch=master)](https://travis-ci.org/fastnlp/fastNLP))
|
||||
|
||||
[//]: # ([![codecov](https://codecov.io/gh/fastnlp/fastNLP/branch/master/graph/badge.svg)](https://codecov.io/gh/fastnlp/fastNLP))
|
||||
|
||||
[//]: # ([![Pypi](https://img.shields.io/pypi/v/fastNLP.svg)](https://pypi.org/project/fastNLP))
|
||||
|
||||
[//]: # (![Hex.pm](https://img.shields.io/hexpm/l/plug.svg))
|
||||
|
||||
[//]: # ([![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest))
|
||||
|
||||
|
||||
fastNLP是一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等。
|
||||
|
||||
fastNLP具有如下的特性:
|
||||
|
||||
- 统一的Tabular式数据容器,简化数据预处理过程;
|
||||
- 内置多种数据集的Loader和Pipe,省去预处理代码;
|
||||
- 各种方便的NLP工具,例如Embedding加载(包括ELMo和BERT)、中间数据cache等;
|
||||
- 部分[数据集与预训练模型](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0)的自动下载;
|
||||
- 提供多种神经网络组件以及复现模型(涵盖中文分词、命名实体识别、句法分析、文本分类、文本匹配、指代消解、摘要等任务);
|
||||
- Trainer提供多种内置Callback函数,方便实验记录、异常捕获等。
|
||||
- 便捷。在数据处理中可以通过apply函数避免循环、使用多进程提速等;在训练循环阶段可以很方便定制操作。
|
||||
- 高效。无需改动代码,实现fp16切换、多卡、ZeRO优化等。
|
||||
- 兼容。fastNLP支持多种深度学习框架作为后端。
|
||||
|
||||
> :warning: **为了实现对不同深度学习架构的兼容,fastNLP 1.0.0之后的版本重新设计了架构,因此与过去的fastNLP版本不完全兼容,
|
||||
> 基于更早的fastNLP代码需要做一定的调整**:
|
||||
|
||||
## fastNLP文档
|
||||
[中文文档](http://www.fastnlp.top/docs/fastNLP/master/index.html)
|
||||
|
||||
## 安装指南
|
||||
|
||||
fastNLP 依赖以下包:
|
||||
|
||||
+ numpy>=1.14.2
|
||||
+ torch>=1.0.0
|
||||
+ tqdm>=4.28.1
|
||||
+ nltk>=3.4.1
|
||||
+ requests
|
||||
+ spacy
|
||||
+ prettytable>=0.7.2
|
||||
|
||||
其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 [PyTorch 官网](https://pytorch.org/) 。
|
||||
在依赖包安装完成后,您可以在命令行执行如下指令完成安装
|
||||
|
||||
fastNLP可以通过以下的命令进行安装
|
||||
```shell
|
||||
pip install fastNLP
|
||||
python -m spacy download en
|
||||
pip install fastNLP>=1.0.0alpha
|
||||
```
|
||||
如果需要安装更早版本的fastNLP请指定版本号,例如
|
||||
```shell
|
||||
pip install fastNLP==0.7.1
|
||||
```
|
||||
另外,请根据使用的深度学习框架,安装相应的深度学习框架。
|
||||
|
||||
<details>
|
||||
<summary>Pytorch</summary>
|
||||
下面是使用pytorch来进行文本分类的例子。需要安装torch>=1.6.0。
|
||||
|
||||
```python
|
||||
from fastNLP.io import ChnSentiCorpLoader
|
||||
from functools import partial
|
||||
from fastNLP import cache_results
|
||||
from fastNLP.transformers.torch import BertTokenizer
|
||||
|
||||
# 使用cache_results装饰器装饰函数,将prepare_data的返回结果缓存到caches/cache.pkl,再次运行时,如果
|
||||
# 该文件还存在,将自动读取缓存文件,而不再次运行预处理代码。
|
||||
@cache_results('caches/cache.pkl')
|
||||
def prepare_data():
|
||||
# 会自动下载数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的
|
||||
data_bundle = ChnSentiCorpLoader().load()
|
||||
# 使用tokenizer对数据进行tokenize
|
||||
tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm')
|
||||
tokenize = partial(tokenizer, max_length=256) # 限制数据的最大长度
|
||||
data_bundle.apply_field_more(tokenize, field_name='raw_chars', num_proc=4) # 会新增"input_ids", "attention_mask"等field进入dataset中
|
||||
data_bundle.apply_field(int, field_name='target', new_field_name='labels') # 将int函数应用到每个target上,并且放入新的labels field中
|
||||
return data_bundle
|
||||
data_bundle = prepare_data()
|
||||
print(data_bundle.get_dataset('train')[:4])
|
||||
|
||||
# 初始化model, optimizer
|
||||
from fastNLP.transformers.torch import BertForSequenceClassification
|
||||
from torch import optim
|
||||
model = BertForSequenceClassification.from_pretrained('hfl/chinese-bert-wwm')
|
||||
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
|
||||
|
||||
# 准备dataloader
|
||||
from fastNLP import prepare_dataloader
|
||||
dls = prepare_dataloader(data_bundle, batch_size=32)
|
||||
|
||||
# 准备训练
|
||||
from fastNLP import Trainer, Accuracy, LoadBestModelCallback, TorchWarmupCallback, Event
|
||||
callbacks = [
|
||||
TorchWarmupCallback(warmup=0.1, schedule='linear'), # 训练过程中调整学习率。
|
||||
LoadBestModelCallback() # 将在训练结束之后,加载性能最优的model
|
||||
]
|
||||
# 在训练特定时机加入一些操作, 不同时机能够获取到的参数不一样,可以通过Trainer.on函数的文档查看每个时机的参数
|
||||
@Trainer.on(Event.on_before_backward())
|
||||
def print_loss(trainer, outputs):
|
||||
if trainer.global_forward_batches % 10 == 0: # 每10个batch打印一次loss。
|
||||
print(outputs.loss.item())
|
||||
|
||||
trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer,
|
||||
device=0, evaluate_dataloaders=dls['dev'], metrics={'acc': Accuracy()},
|
||||
callbacks=callbacks, monitor='acc#acc',n_epochs=5,
|
||||
# Accuracy的update()函数需要pred,target两个参数,它们实际对应的就是以下的field。
|
||||
evaluate_input_mapping={'labels': 'target'}, # 在评测时,将dataloader中会输入到模型的labels重新命名为target
|
||||
evaluate_output_mapping={'logits': 'pred'} # 在评测时,将model输出中的logits重新命名为pred
|
||||
)
|
||||
trainer.run()
|
||||
|
||||
# 在测试集合上进行评测
|
||||
from fastNLP import Evaluator
|
||||
evaluator = Evaluator(model=model, dataloaders=dls['test'], metrics={'acc': Accuracy()},
|
||||
# Accuracy的update()函数需要pred,target两个参数,它们实际对应的就是以下的field。
|
||||
output_mapping={'logits': 'pred'},
|
||||
input_mapping={'labels': 'target'})
|
||||
evaluator.run()
|
||||
```
|
||||
|
||||
|
||||
## fastNLP教程
|
||||
中文[文档](http://www.fastnlp.top/docs/fastNLP/)、 [教程](http://www.fastnlp.top/docs/fastNLP/user/quickstart.html)
|
||||
|
||||
更多内容可以参考如下的链接
|
||||
### 快速入门
|
||||
|
||||
- [Quick-1. 文本分类](http://www.fastnlp.top/docs/fastNLP/tutorials/%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB.html)
|
||||
- [Quick-2. 序列标注](http://www.fastnlp.top/docs/fastNLP/tutorials/%E5%BA%8F%E5%88%97%E6%A0%87%E6%B3%A8.html)
|
||||
- [0. 10 分钟快速上手 fastNLP torch](http://www.fastnlp.top/docs/fastNLP/master/tutorials/torch/fastnlp_torch_tutorial.html)
|
||||
|
||||
### 详细使用教程
|
||||
|
||||
- [1. 使用DataSet预处理文本](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_1_data_preprocess.html)
|
||||
- [2. 使用Vocabulary转换文本与index](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_2_vocabulary.html)
|
||||
- [3. 使用Embedding模块将文本转成向量](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_3_embedding.html)
|
||||
- [4. 使用Loader和Pipe加载并处理数据集](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_4_load_dataset.html)
|
||||
- [5. 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_5_loss_optimizer.html)
|
||||
- [6. 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html)
|
||||
- [7. 使用Metric快速评测你的模型](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_7_metrics.html)
|
||||
- [8. 使用Modules和Models快速搭建自定义模型](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_8_modules_models.html)
|
||||
- [9. 使用Callback自定义你的训练过程](http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_9_callback.html)
|
||||
|
||||
### 扩展教程
|
||||
|
||||
- [Extend-1. BertEmbedding的各种用法](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_1_bert_embedding.html)
|
||||
- [Extend-2. 分布式训练简介](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_2_dist.html)
|
||||
- [Extend-3. 使用fitlog 辅助 fastNLP 进行科研](http://www.fastnlp.top/docs/fastNLP/tutorials/extend_3_fitlog.html)
|
||||
- [1. Trainer 和 Evaluator 的基本使用](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_0.html)
|
||||
- [2. DataSet 和 Vocabulary 的基本使用](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_1.html)
|
||||
- [3. DataBundle 和 Tokenizer 的基本使用](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_2.html)
|
||||
- [4. TorchDataloader 的内部结构和基本使用](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_3.html)
|
||||
- [5. fastNLP 中的预定义模型](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_4.html)
|
||||
- [6. Trainer 和 Evaluator 的深入介绍](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_4.html)
|
||||
- [7. fastNLP 与 paddle 或 jittor 的结合](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_5.html)
|
||||
- [8. 使用 Bert + fine-tuning 完成 SST-2 分类](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_e1.html)
|
||||
- [9. 使用 Bert + prompt 完成 SST-2 分类](http://www.fastnlp.top/docs/fastNLP/master/tutorials/basic/fastnlp_tutorial_e2.html)
|
||||
|
||||
|
||||
## 内置组件
|
||||
</details>
|
||||
|
||||
大部分用于的 NLP 任务神经网络都可以看做由词嵌入(embeddings)和两种模块:编码器(encoder)、解码器(decoder)组成。
|
||||
<details>
|
||||
<summary>Paddle</summary>
|
||||
下面是使用paddle来进行文本分类的例子。需要安装paddle>=2.2.0以及paddlenlp>=2.3.3。
|
||||
|
||||
以文本分类任务为例,下图展示了一个BiLSTM+Attention实现文本分类器的模型流程图:
|
||||
```python
|
||||
from fastNLP.io import ChnSentiCorpLoader
|
||||
from functools import partial
|
||||
|
||||
# 会自动下载数据,并且可以通过文档看到返回的 dataset 应该是包含"raw_words"和"target"两个field的
|
||||
data_bundle = ChnSentiCorpLoader().load()
|
||||
|
||||
# 使用tokenizer对数据进行tokenize
|
||||
from paddlenlp.transformers import BertTokenizer
|
||||
tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm')
|
||||
tokenize = partial(tokenizer, max_length=256) # 限制一下最大长度
|
||||
data_bundle.apply_field_more(tokenize, field_name='raw_chars', num_proc=4) # 会新增"input_ids", "attention_mask"等field进入dataset中
|
||||
data_bundle.apply_field(int, field_name='target', new_field_name='labels') # 将int函数应用到每个target上,并且放入新的labels field中
|
||||
print(data_bundle.get_dataset('train')[:4])
|
||||
|
||||
# 初始化 model
|
||||
from paddlenlp.transformers import BertForSequenceClassification, LinearDecayWithWarmup
|
||||
from paddle import optimizer, nn
|
||||
class SeqClsModel(nn.Layer):
|
||||
def __init__(self, model_checkpoint, num_labels):
|
||||
super(SeqClsModel, self).__init__()
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertForSequenceClassification.from_pretrained(model_checkpoint)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
|
||||
logits = self.bert(input_ids, token_type_ids, position_ids, attention_mask)
|
||||
return logits
|
||||
|
||||
def train_step(self, input_ids, labels, token_type_ids=None, position_ids=None, attention_mask=None):
|
||||
logits = self(input_ids, token_type_ids, position_ids, attention_mask)
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(logits.reshape((-1, self.num_labels)), labels.reshape((-1, )))
|
||||
return {
|
||||
"logits": logits,
|
||||
"loss": loss,
|
||||
}
|
||||
|
||||
def evaluate_step(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
|
||||
logits = self(input_ids, token_type_ids, position_ids, attention_mask)
|
||||
return {
|
||||
"logits": logits,
|
||||
}
|
||||
|
||||
model = SeqClsModel('hfl/chinese-bert-wwm', num_labels=2)
|
||||
|
||||
# 准备dataloader
|
||||
from fastNLP import prepare_dataloader
|
||||
dls = prepare_dataloader(data_bundle, batch_size=16)
|
||||
|
||||
# 训练过程中调整学习率。
|
||||
scheduler = LinearDecayWithWarmup(2e-5, total_steps=20 * len(dls['train']), warmup=0.1)
|
||||
optimizer = optimizer.AdamW(parameters=model.parameters(), learning_rate=scheduler)
|
||||
|
||||
# 准备训练
|
||||
from fastNLP import Trainer, Accuracy, LoadBestModelCallback, Event
|
||||
callbacks = [
|
||||
LoadBestModelCallback() # 将在训练结束之后,加载性能最优的model
|
||||
]
|
||||
# 在训练特定时机加入一些操作, 不同时机能够获取到的参数不一样,可以通过Trainer.on函数的文档查看每个时机的参数
|
||||
@Trainer.on(Event.on_before_backward())
|
||||
def print_loss(trainer, outputs):
|
||||
if trainer.global_forward_batches % 10 == 0: # 每10个batch打印一次loss。
|
||||
print(outputs["loss"].item())
|
||||
|
||||
trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer,
|
||||
device=0, evaluate_dataloaders=dls['dev'], metrics={'acc': Accuracy()},
|
||||
callbacks=callbacks, monitor='acc#acc',
|
||||
# Accuracy的update()函数需要pred,target两个参数,它们实际对应的就是以下的field。
|
||||
evaluate_output_mapping={'logits': 'pred'},
|
||||
evaluate_input_mapping={'labels': 'target'}
|
||||
)
|
||||
trainer.run()
|
||||
|
||||
# 在测试集合上进行评测
|
||||
from fastNLP import Evaluator
|
||||
evaluator = Evaluator(model=model, dataloaders=dls['test'], metrics={'acc': Accuracy()},
|
||||
# Accuracy的update()函数需要pred,target两个参数,它们实际对应的就是以下的field。
|
||||
output_mapping={'logits': 'pred'},
|
||||
input_mapping={'labels': 'target'})
|
||||
evaluator.run()
|
||||
```
|
||||
|
||||
更多内容可以参考如下的链接
|
||||
### 快速入门
|
||||
|
||||
- [0. 10 分钟快速上手 fastNLP paddle](http://www.fastnlp.top/docs/fastNLP/master/tutorials/torch/fastnlp_torch_tutorial.html)
|
||||
|
||||
### 详细使用教程
|
||||
|
||||
- [1. 使用 paddlenlp 和 fastNLP 实现中文文本情感分析](http://www.fastnlp.top/docs/fastNLP/master/tutorials/paddle/fastnlp_tutorial_paddle_e1.html)
|
||||
- [2. 使用 paddlenlp 和 fastNLP 训练中文阅读理解任务](http://www.fastnlp.top/docs/fastNLP/master/tutorials/paddle/fastnlp_tutorial_paddle_e2.html)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>oneflow</summary>
|
||||
</details>
|
||||
|
||||
|
||||
![](./docs/source/figures/text_classification.png)
|
||||
|
||||
fastNLP 在 embeddings 模块中内置了几种不同的embedding:静态embedding(GloVe、word2vec)、上下文相关embedding
|
||||
(ELMo、BERT)、字符embedding(基于CNN或者LSTM的CharEmbedding)
|
||||
|
||||
与此同时,fastNLP 在 modules 模块中内置了两种模块的诸多组件,可以帮助用户快速搭建自己所需的网络。 两种模块的功能和常见组件如下:
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><b> 类型 </b></td>
|
||||
<td><b> 功能 </b></td>
|
||||
<td><b> 例子 </b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td> encoder </td>
|
||||
<td> 将输入编码为具有具有表示能力的向量 </td>
|
||||
<td> Embedding, RNN, CNN, Transformer, ...
|
||||
</tr>
|
||||
<tr>
|
||||
<td> decoder </td>
|
||||
<td> 将具有某种表示意义的向量解码为需要的输出形式 </td>
|
||||
<td> MLP, CRF, ... </td>
|
||||
</tr>
|
||||
</table>
|
||||
<details>
|
||||
<summary>jittor</summary>
|
||||
</details>
|
||||
|
||||
|
||||
## 项目结构
|
||||
|
||||
<div align=center><img width="450" height="350" src="./docs/source/figures/workflow.png"/></div>
|
||||
|
||||
|
||||
|
||||
fastNLP的大致工作流程如上图所示,而项目结构如下:
|
||||
fastNLP的项目结构如下:
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
@ -135,4 +264,3 @@ fastNLP的大致工作流程如上图所示,而项目结构如下:
|
||||
|
||||
<hr>
|
||||
|
||||
*In memory of @FengZiYjun. May his soul rest in peace. We will miss you very very much!*
|
||||
|
@ -1,5 +0,0 @@
|
||||
ignore:
|
||||
- "reproduction" # ignore folders and all its contents
|
||||
- "setup.py"
|
||||
- "docs"
|
||||
- "tutorials"
|
@ -6,24 +6,35 @@ SPHINXOPTS =
|
||||
SPHINXAPIDOC = sphinx-apidoc
|
||||
SPHINXBUILD = sphinx-build
|
||||
SPHINXPROJ = fastNLP
|
||||
SPHINXEXCLUDE = ../fastNLP/transformers/*
|
||||
SOURCEDIR = source
|
||||
BUILDDIR = build
|
||||
PORT = 8000
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS)
|
||||
|
||||
apidoc:
|
||||
$(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ)
|
||||
$(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) $(SPHINXEXCLUDE)
|
||||
|
||||
server:
|
||||
cd build/html && python -m http.server
|
||||
cd build/html && python -m http.server $(PORT)
|
||||
|
||||
delete:
|
||||
rm -f source/$(SPHINXPROJ).* source/modules.rst && rm -rf build
|
||||
|
||||
web:
|
||||
make html && make server
|
||||
|
||||
dev:
|
||||
rm -f source/$(SPHINXPROJ).* source/modules.rst && rm -rf build && make apidoc && make html && make server
|
||||
make delete && make apidoc && make html && make server
|
||||
|
||||
versions:
|
||||
sphinx-multiversion "$(SOURCEDIR)" "$(BUILDDIR)" && cd build && python -m http.server $(PORT)
|
||||
|
||||
prod:
|
||||
make apidoc && make html
|
||||
make apidoc && make html
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
|
@ -1,40 +0,0 @@
|
||||
# 快速入门 fastNLP 文档编写
|
||||
|
||||
本教程为 fastNLP 文档编写者创建,文档编写者包括合作开发人员和文档维护人员。您在一般情况下属于前者,
|
||||
只需要了解整个框架的部分内容即可。
|
||||
|
||||
## 合作开发人员
|
||||
|
||||
FastNLP的文档使用基于[reStructuredText标记语言](http://docutils.sourceforge.net/rst.html)的
|
||||
[Sphinx](http://sphinx.pocoo.org/)工具生成,由[Read the Docs](https://readthedocs.org/)网站自动维护生成。
|
||||
一般开发者只要编写符合reStructuredText语法规范的文档并通过[PR](https://help.github.com/en/articles/about-pull-requests),
|
||||
就可以为fastNLP的文档贡献一份力量。
|
||||
|
||||
如果你想在本地编译文档并进行大段文档的编写,您需要安装Sphinx工具以及sphinx-rtd-theme主题:
|
||||
```bash
|
||||
fastNLP/docs> pip install sphinx
|
||||
fastNLP/docs> pip install sphinx-rtd-theme
|
||||
```
|
||||
然后在本目录下执行 `make dev` 命令。该命令只支持Linux和MacOS系统,期望看到如下输出:
|
||||
```bash
|
||||
fastNLP/docs> make dev
|
||||
rm -rf build/html && make html && make server
|
||||
Running Sphinx v1.5.6
|
||||
making output directory...
|
||||
......
|
||||
Build finished. The HTML pages are in build/html.
|
||||
cd build/html && python -m http.server
|
||||
Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ...
|
||||
```
|
||||
现在您浏览器访问 http://localhost:8000/ 查看文档。如果你在远程服务器尚进行工作,则访问地址为 http://{服务器的ip地址}:8000/ 。
|
||||
但您必须保证服务器的8000端口是开放的。如果您的电脑或远程服务器的8000端口被占用,程序会顺延使用8001、8002……等端口。
|
||||
当你结束访问时,您可以使用Control(Ctrl) + C 来结束进程。
|
||||
|
||||
我们在[这里](./source/user/example.rst)列举了fastNLP文档经常用到的reStructuredText语法(网页查看请结合Raw模式),
|
||||
您可以通过阅读它进行快速上手。FastNLP大部分的文档都是写在代码中通过Sphinx工具进行抽取生成的,
|
||||
|
||||
## 文档维护人员
|
||||
|
||||
文档维护人员需要了解 Makefile 中全部命令的含义,并了解到目前的文档结构
|
||||
是在 sphinx-apidoc 自动抽取的基础上进行手动修改得到的。
|
||||
文档维护人员应进一步提升整个框架的自动化程度,并监督合作开发人员不要破坏文档项目的整体结构。
|
@ -1,191 +0,0 @@
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def _colored_string(string: str, color: str or int) -> str:
|
||||
"""在终端中显示一串有颜色的文字
|
||||
:param string: 在终端中显示的文字
|
||||
:param color: 文字的颜色
|
||||
:return:
|
||||
"""
|
||||
if isinstance(color, str):
|
||||
color = {
|
||||
"black": 30, "Black": 30, "BLACK": 30,
|
||||
"red": 31, "Red": 31, "RED": 31,
|
||||
"green": 32, "Green": 32, "GREEN": 32,
|
||||
"yellow": 33, "Yellow": 33, "YELLOW": 33,
|
||||
"blue": 34, "Blue": 34, "BLUE": 34,
|
||||
"purple": 35, "Purple": 35, "PURPLE": 35,
|
||||
"cyan": 36, "Cyan": 36, "CYAN": 36,
|
||||
"white": 37, "White": 37, "WHITE": 37
|
||||
}[color]
|
||||
return "\033[%dm%s\033[0m" % (color, string)
|
||||
|
||||
|
||||
def gr(string, flag):
|
||||
if flag:
|
||||
return _colored_string(string, "green")
|
||||
else:
|
||||
return _colored_string(string, "red")
|
||||
|
||||
|
||||
def find_all_modules():
|
||||
modules = {}
|
||||
children = {}
|
||||
to_doc = set()
|
||||
root = '../fastNLP'
|
||||
for path, dirs, files in os.walk(root):
|
||||
for file in files:
|
||||
if file.endswith('.py'):
|
||||
name = ".".join(path.split('/')[1:])
|
||||
if file.split('.')[0] != "__init__":
|
||||
name = name + '.' + file.split('.')[0]
|
||||
__import__(name)
|
||||
m = sys.modules[name]
|
||||
modules[name] = m
|
||||
try:
|
||||
m.__all__
|
||||
except:
|
||||
print(name, "__all__ missing")
|
||||
continue
|
||||
if m.__doc__ is None:
|
||||
print(name, "__doc__ missing")
|
||||
continue
|
||||
if "undocumented" not in m.__doc__:
|
||||
to_doc.add(name)
|
||||
for module in to_doc:
|
||||
t = ".".join(module.split('.')[:-1])
|
||||
if t in to_doc:
|
||||
if t not in children:
|
||||
children[t] = set()
|
||||
children[t].add(module)
|
||||
for m in children:
|
||||
children[m] = sorted(children[m])
|
||||
return modules, to_doc, children
|
||||
|
||||
|
||||
def create_rst_file(modules, name, children):
|
||||
m = modules[name]
|
||||
with open("./source/" + name + ".rst", "w") as fout:
|
||||
t = "=" * len(name)
|
||||
fout.write(name + "\n")
|
||||
fout.write(t + "\n")
|
||||
fout.write("\n")
|
||||
fout.write(".. automodule:: " + name + "\n")
|
||||
if name != "fastNLP.core" and len(m.__all__) > 0:
|
||||
fout.write(" :members: " + ", ".join(m.__all__) + "\n")
|
||||
short = name[len("fastNLP."):]
|
||||
if not (short.startswith('models') or short.startswith('modules') or short.startswith('embeddings')):
|
||||
fout.write(" :inherited-members:\n")
|
||||
fout.write("\n")
|
||||
if name in children:
|
||||
fout.write("子模块\n------\n\n.. toctree::\n :maxdepth: 1\n\n")
|
||||
for module in children[name]:
|
||||
fout.write(" " + module + "\n")
|
||||
|
||||
|
||||
def check_file(m, name):
|
||||
names = name.split('.')
|
||||
test_name = "test." + ".".join(names[1:-1]) + ".test_" + names[-1]
|
||||
try:
|
||||
__import__(test_name)
|
||||
tm = sys.modules[test_name]
|
||||
except ModuleNotFoundError:
|
||||
tm = None
|
||||
tested = tm is not None
|
||||
funcs = {}
|
||||
classes = {}
|
||||
for item, obj in inspect.getmembers(m):
|
||||
if inspect.isclass(obj) and obj.__module__ == name and not obj.__name__.startswith('_'):
|
||||
this = (obj.__doc__ is not None, tested and obj.__name__ in dir(tm), {})
|
||||
for i in dir(obj):
|
||||
func = getattr(obj, i)
|
||||
if inspect.isfunction(func) and not i.startswith('_'):
|
||||
this[2][i] = (func.__doc__ is not None, False)
|
||||
classes[obj.__name__] = this
|
||||
if inspect.isfunction(obj) and obj.__module__ == name and not obj.__name__.startswith('_'):
|
||||
this = (obj.__doc__ is not None, tested and obj.__name__ in dir(tm)) # docs
|
||||
funcs[obj.__name__] = this
|
||||
return funcs, classes
|
||||
|
||||
|
||||
def check_files(modules, out=None):
|
||||
for name in sorted(modules.keys()):
|
||||
print(name, file=out)
|
||||
funcs, classes = check_file(modules[name], name)
|
||||
if out is None:
|
||||
for f in funcs:
|
||||
print("%-30s \t %s \t %s" % (f, gr("文档", funcs[f][0]), gr("测试", funcs[f][1])))
|
||||
for c in classes:
|
||||
print("%-30s \t %s \t %s" % (c, gr("文档", classes[c][0]), gr("测试", classes[c][1])))
|
||||
methods = classes[c][2]
|
||||
for f in methods:
|
||||
print(" %-28s \t %s" % (f, gr("文档", methods[f][0])))
|
||||
else:
|
||||
for f in funcs:
|
||||
if not funcs[f][0]:
|
||||
print("缺少文档 %s" % (f), file=out)
|
||||
if not funcs[f][1]:
|
||||
print("缺少测试 %s" % (f), file=out)
|
||||
for c in classes:
|
||||
if not classes[c][0]:
|
||||
print("缺少文档 %s" % (c), file=out)
|
||||
if not classes[c][1]:
|
||||
print("缺少测试 %s" % (c), file=out)
|
||||
methods = classes[c][2]
|
||||
for f in methods:
|
||||
if not methods[f][0]:
|
||||
print("缺少文档 %s" % (c + "." + f), file=out)
|
||||
print(file=out)
|
||||
|
||||
|
||||
def main_check():
|
||||
sys.path.append("..")
|
||||
print(_colored_string('Getting modules...', "Blue"))
|
||||
modules, to_doc, children = find_all_modules()
|
||||
print(_colored_string('Done!', "Green"))
|
||||
print(_colored_string('Creating rst files...', "Blue"))
|
||||
for name in to_doc:
|
||||
create_rst_file(modules, name, children)
|
||||
print(_colored_string('Done!', "Green"))
|
||||
print(_colored_string('Checking all files...', "Blue"))
|
||||
check_files(modules, out=open("results.txt", "w"))
|
||||
print(_colored_string('Done!', "Green"))
|
||||
|
||||
|
||||
def check_file_r(file_path):
|
||||
with open(file_path) as fin:
|
||||
content = fin.read()
|
||||
index = -3
|
||||
cuts = []
|
||||
while index != -1:
|
||||
index = content.find('"""',index+3)
|
||||
cuts.append(index)
|
||||
cuts = cuts[:-1]
|
||||
assert len(cuts)%2 == 0
|
||||
write_content = ""
|
||||
last = 0
|
||||
for i in range(len(cuts)//2):
|
||||
start, end = cuts[i+i], cuts[i+i+1]
|
||||
if content[start-1] == "r":
|
||||
write_content += content[last:end+3]
|
||||
else:
|
||||
write_content += content[last:start] + "r"
|
||||
write_content += content[start:end+3]
|
||||
last = end + 3
|
||||
write_content += content[last:]
|
||||
with open(file_path, "w") as fout:
|
||||
fout.write(write_content)
|
||||
|
||||
|
||||
def add_r(base_path='../fastNLP'):
|
||||
for path, _, files in os.walk(base_path):
|
||||
for f in files:
|
||||
if f.endswith(".py"):
|
||||
check_file_r(os.path.abspath(os.path.join(path,f)))
|
||||
# sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
add_r()
|
@ -1,4 +1,4 @@
|
||||
sphinx==3.2.1
|
||||
docutils==0.16
|
||||
sphinx-rtd-theme==0.5.0
|
||||
readthedocs-sphinx-search==0.1.0rc3
|
||||
sphinx
|
||||
sphinx_rtd_theme
|
||||
sphinx_autodoc_typehints
|
||||
sphinx-multiversion
|
@ -1,260 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# BertEmbedding的各种用法\n",
|
||||
"Bert自从在 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 中被提出后,因其性能卓越受到了极大的关注,在这里我们展示一下在fastNLP中如何使用Bert进行各类任务。其中中文Bert我们使用的模型的权重来自于 中文Bert预训练 。\n",
|
||||
"\n",
|
||||
"为了方便大家的使用,fastNLP提供了预训练的Embedding权重及数据集的自动下载,支持自动下载的Embedding和数据集见 数据集 。或您可从 使用Embedding模块将文本转成向量 与 使用Loader和Pipe加载并处理数据集 了解更多相关信息\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"下面我们将介绍通过使用Bert来进行文本分类, 中文命名实体识别, 文本匹配, 中文问答。\n",
|
||||
"\n",
|
||||
"## 1. 使用Bert进行文本分类\n",
|
||||
"\n",
|
||||
"文本分类是指给定一段文字,判定其所属的类别。例如下面的文本情感分类\n",
|
||||
"\n",
|
||||
" *1, 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!*\n",
|
||||
"\n",
|
||||
"这里我们使用fastNLP提供自动下载的微博分类进行测试"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP.io import WeiboSenti100kPipe\n",
|
||||
"from fastNLP.embeddings import BertEmbedding\n",
|
||||
"from fastNLP.models import BertForSequenceClassification\n",
|
||||
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"data_bundle =WeiboSenti100kPipe().process_from_file()\n",
|
||||
"data_bundle.rename_field('chars', 'words')\n",
|
||||
"\n",
|
||||
"# 载入BertEmbedding\n",
|
||||
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)\n",
|
||||
"\n",
|
||||
"# 载入模型\n",
|
||||
"model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))\n",
|
||||
"\n",
|
||||
"# 训练模型\n",
|
||||
"device = 0 if torch.cuda.is_available() else 'cpu' \n",
|
||||
"trainer = Trainer(data_bundle.get_dataset('train'), model,\n",
|
||||
" optimizer=Adam(model_params=model.parameters(), lr=2e-5),\n",
|
||||
" loss=CrossEntropyLoss(), device=device,\n",
|
||||
" batch_size=8, dev_data=data_bundle.get_dataset('dev'),\n",
|
||||
" metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n",
|
||||
"trainer.train()\n",
|
||||
"\n",
|
||||
"# 测试结果\n",
|
||||
"from fastNLP import Tester\n",
|
||||
"\n",
|
||||
"tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())\n",
|
||||
"tester.test()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. 使用Bert进行命名实体识别\n",
|
||||
"\n",
|
||||
"命名实体识别是给定一句话,标记出其中的实体。一般序列标注的任务都使用conll格式,conll格式是至一行中通过制表符分隔不同的内容,使用空行分隔 两句话,例如下面的例子\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
" 中 B-ORG\n",
|
||||
" 共 I-ORG\n",
|
||||
" 中 I-ORG\n",
|
||||
" 央 I-ORG\n",
|
||||
" 致 O\n",
|
||||
" 中 B-ORG\n",
|
||||
" 国 I-ORG\n",
|
||||
" 致 I-ORG\n",
|
||||
" 公 I-ORG\n",
|
||||
" 党 I-ORG\n",
|
||||
" 十 I-ORG\n",
|
||||
" 一 I-ORG\n",
|
||||
" 大 I-ORG\n",
|
||||
" 的 O\n",
|
||||
" 贺 O\n",
|
||||
" 词 O\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"这部分内容请参考 快速实现序列标注模型\n",
|
||||
"\n",
|
||||
"## 3. 使用Bert进行文本匹配\n",
|
||||
"\n",
|
||||
"文本匹配任务是指给定两句话判断他们的关系。比如,给定两句话判断前一句是否和后一句具有因果关系或是否是矛盾关系;或者给定两句话判断两句话是否 具有相同的意思。这里我们使用"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP.io import CNXNLIBertPipe\n",
|
||||
"from fastNLP.embeddings import BertEmbedding\n",
|
||||
"from fastNLP.models import BertForSentenceMatching\n",
|
||||
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
|
||||
"from fastNLP.core.optimizer import AdamW\n",
|
||||
"from fastNLP.core.callback import WarmupCallback\n",
|
||||
"from fastNLP import Tester\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"data_bundle = CNXNLIBertPipe().process_from_file()\n",
|
||||
"data_bundle.rename_field('chars', 'words')\n",
|
||||
"print(data_bundle)\n",
|
||||
"\n",
|
||||
"# 载入BertEmbedding\n",
|
||||
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)\n",
|
||||
"\n",
|
||||
"# 载入模型\n",
|
||||
"model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))\n",
|
||||
"\n",
|
||||
"# 训练模型\n",
|
||||
"callbacks = [WarmupCallback(warmup=0.1, schedule='linear'), ]\n",
|
||||
"device = 0 if torch.cuda.is_available() else 'cpu' \n",
|
||||
"trainer = Trainer(data_bundle.get_dataset('train'), model,\n",
|
||||
" optimizer=AdamW(params=model.parameters(), lr=4e-5),\n",
|
||||
" loss=CrossEntropyLoss(), device=device,\n",
|
||||
" batch_size=8, dev_data=data_bundle.get_dataset('dev'),\n",
|
||||
" metrics=AccuracyMetric(), n_epochs=5, print_every=1,\n",
|
||||
" update_every=8, callbacks=callbacks)\n",
|
||||
"trainer.train()\n",
|
||||
"\n",
|
||||
"tester = Tester(data_bundle.get_dataset('test'), model, batch_size=8, metrics=AccuracyMetric())\n",
|
||||
"tester.test()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. 使用Bert进行中文问答\n",
|
||||
"\n",
|
||||
"问答任务是给定一段内容,以及一个问题,需要从这段内容中找到答案。 例如:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"\"context\": \"锣鼓经是大陆传统器乐及戏曲里面常用的打击乐记谱方法,以中文字的声音模拟敲击乐的声音,纪录打击乐的各种不同的演奏方法。常\n",
|
||||
"用的节奏型称为「锣鼓点」。而锣鼓是戏曲节奏的支柱,除了加强演员身段动作的节奏感,也作为音乐的引子和尾声,提示音乐的板式和速度,以及\n",
|
||||
"作为唱腔和念白的伴奏,令诗句的韵律更加抑扬顿锉,段落分明。锣鼓的运用有约定俗成的程式,依照角色行当的身份、性格、情绪以及环境,配合\n",
|
||||
"相应的锣鼓点。锣鼓亦可以模仿大自然的音响效果,如雷电、波浪等等。戏曲锣鼓所运用的敲击乐器主要分为鼓、锣、钹和板四类型:鼓类包括有单\n",
|
||||
"皮鼓(板鼓)、大鼓、大堂鼓(唐鼓)、小堂鼓、怀鼓、花盆鼓等;锣类有大锣、小锣(手锣)、钲锣、筛锣、马锣、镗锣、云锣;钹类有铙钹、大\n",
|
||||
"钹、小钹、水钹、齐钹、镲钹、铰子、碰钟等;打拍子用的檀板、木鱼、梆子等。因为京剧的锣鼓通常由四位乐师负责,又称为四大件,领奏的师\n",
|
||||
"傅称为:「鼓佬」,其职责有如西方乐队的指挥,负责控制速度以及利用各种手势提示乐师演奏不同的锣鼓点。粤剧吸收了部份京剧的锣鼓,但以木鱼\n",
|
||||
"和沙的代替了京剧的板和鼓,作为打拍子的主要乐器。以下是京剧、昆剧和粤剧锣鼓中乐器对应的口诀用字:\",\n",
|
||||
"\"question\": \"锣鼓经是什么?\",\n",
|
||||
"\"answers\": [\n",
|
||||
" {\n",
|
||||
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
|
||||
" \"answer_start\": 4\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
|
||||
" \"answer_start\": 4\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n",
|
||||
" \"answer_start\": 4\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"您可以通过以下的代码训练 (原文代码:[CMRC2018](https://github.com/ymcui/cmrc2018) )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import BertEmbedding\n",
|
||||
"from fastNLP.models import BertForQuestionAnswering\n",
|
||||
"from fastNLP.core.losses import CMRC2018Loss\n",
|
||||
"from fastNLP.core.metrics import CMRC2018Metric\n",
|
||||
"from fastNLP.io.pipe.qa import CMRC2018BertPipe\n",
|
||||
"from fastNLP import Trainer, BucketSampler\n",
|
||||
"from fastNLP import WarmupCallback, GradientClipCallback\n",
|
||||
"from fastNLP.core.optimizer import AdamW\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"data_bundle = CMRC2018BertPipe().process_from_file()\n",
|
||||
"data_bundle.rename_field('chars', 'words')\n",
|
||||
"\n",
|
||||
"print(data_bundle)\n",
|
||||
"\n",
|
||||
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn', requires_grad=True, include_cls_sep=False, auto_truncate=True,\n",
|
||||
" dropout=0.5, word_dropout=0.01)\n",
|
||||
"model = BertForQuestionAnswering(embed)\n",
|
||||
"loss = CMRC2018Loss()\n",
|
||||
"metric = CMRC2018Metric()\n",
|
||||
"\n",
|
||||
"wm_callback = WarmupCallback(schedule='linear')\n",
|
||||
"gc_callback = GradientClipCallback(clip_value=1, clip_type='norm')\n",
|
||||
"callbacks = [wm_callback, gc_callback]\n",
|
||||
"\n",
|
||||
"optimizer = AdamW(model.parameters(), lr=5e-5)\n",
|
||||
"\n",
|
||||
"device = 0 if torch.cuda.is_available() else 'cpu' \n",
|
||||
"trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,\n",
|
||||
" sampler=BucketSampler(seq_len_field_name='context_len'),\n",
|
||||
" dev_data=data_bundle.get_dataset('dev'), metrics=metric,\n",
|
||||
" callbacks=callbacks, device=device, batch_size=6, num_workers=2, n_epochs=2, print_every=1,\n",
|
||||
" test_use_tqdm=False, update_every=10)\n",
|
||||
"trainer.train(load_best_model=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"训练结果(和论文中报道的基本一致):\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
" In Epoch:2/Step:1692, got best dev performance:\n",
|
||||
" CMRC2018Metric: f1=85.61, em=66.08\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python Now",
|
||||
"language": "python",
|
||||
"name": "now"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1,292 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# fastNLP中的DataSet"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"+------------------------------+---------------------------------------------+---------+\n",
|
||||
"| raw_words | words | seq_len |\n",
|
||||
"+------------------------------+---------------------------------------------+---------+\n",
|
||||
"| This is the first instance . | ['this', 'is', 'the', 'first', 'instance... | 6 |\n",
|
||||
"| Second instance . | ['Second', 'instance', '.'] | 3 |\n",
|
||||
"| Third instance . | ['Third', 'instance', '.'] | 3 |\n",
|
||||
"+------------------------------+---------------------------------------------+---------+\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import DataSet\n",
|
||||
"data = {'raw_words':[\"This is the first instance .\", \"Second instance .\", \"Third instance .\"],\n",
|
||||
" 'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.']],\n",
|
||||
" 'seq_len': [6, 3, 3]}\n",
|
||||
"dataset = DataSet(data)\n",
|
||||
"# 传入的dict的每个key的value应该为具有相同长度的list\n",
|
||||
"print(dataset)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## DataSet的构建"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"+----------------------------+---------------------------------------------+---------+\n",
|
||||
"| raw_words | words | seq_len |\n",
|
||||
"+----------------------------+---------------------------------------------+---------+\n",
|
||||
"| This is the first instance | ['this', 'is', 'the', 'first', 'instance... | 6 |\n",
|
||||
"+----------------------------+---------------------------------------------+---------+"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import DataSet\n",
|
||||
"from fastNLP import Instance\n",
|
||||
"dataset = DataSet()\n",
|
||||
"instance = Instance(raw_words=\"This is the first instance\",\n",
|
||||
" words=['this', 'is', 'the', 'first', 'instance', '.'],\n",
|
||||
" seq_len=6)\n",
|
||||
"dataset.append(instance)\n",
|
||||
"dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"+----------------------------+---------------------------------------------+---------+\n",
|
||||
"| raw_words | words | seq_len |\n",
|
||||
"+----------------------------+---------------------------------------------+---------+\n",
|
||||
"| This is the first instance | ['this', 'is', 'the', 'first', 'instance... | 6 |\n",
|
||||
"| Second instance . | ['Second', 'instance', '.'] | 3 |\n",
|
||||
"+----------------------------+---------------------------------------------+---------+"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import DataSet\n",
|
||||
"from fastNLP import Instance\n",
|
||||
"dataset = DataSet([\n",
|
||||
" Instance(raw_words=\"This is the first instance\",\n",
|
||||
" words=['this', 'is', 'the', 'first', 'instance', '.'],\n",
|
||||
" seq_len=6),\n",
|
||||
" Instance(raw_words=\"Second instance .\",\n",
|
||||
" words=['Second', 'instance', '.'],\n",
|
||||
" seq_len=3)\n",
|
||||
" ])\n",
|
||||
"dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## DataSet的删除"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"+----+---+\n",
|
||||
"| a | c |\n",
|
||||
"+----+---+\n",
|
||||
"| -5 | 0 |\n",
|
||||
"| -4 | 0 |\n",
|
||||
"| -3 | 0 |\n",
|
||||
"| -2 | 0 |\n",
|
||||
"| -1 | 0 |\n",
|
||||
"| 0 | 0 |\n",
|
||||
"| 1 | 0 |\n",
|
||||
"| 2 | 0 |\n",
|
||||
"| 3 | 0 |\n",
|
||||
"| 4 | 0 |\n",
|
||||
"+----+---+"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import DataSet\n",
|
||||
"dataset = DataSet({'a': range(-5, 5), 'c': [0]*10})\n",
|
||||
"dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"+---+\n",
|
||||
"| c |\n",
|
||||
"+---+\n",
|
||||
"| 0 |\n",
|
||||
"| 0 |\n",
|
||||
"| 0 |\n",
|
||||
"| 0 |\n",
|
||||
"+---+"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 不改变dataset,生成一个删除了满足条件的instance的新 DataSet\n",
|
||||
"dropped_dataset = dataset.drop(lambda ins:ins['a']<0, inplace=False)\n",
|
||||
"# 在dataset中删除满足条件的instance\n",
|
||||
"dataset.drop(lambda ins:ins['a']<0)\n",
|
||||
"# 删除第3个instance\n",
|
||||
"dataset.delete_instance(2)\n",
|
||||
"# 删除名为'a'的field\n",
|
||||
"dataset.delete_field('a')\n",
|
||||
"dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 简单的数据预处理"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"False\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"4"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 检查是否存在名为'a'的field\n",
|
||||
"print(dataset.has_field('a')) # 或 ('a' in dataset)\n",
|
||||
"# 将名为'a'的field改名为'b'\n",
|
||||
"dataset.rename_field('c', 'b')\n",
|
||||
"# DataSet的长度\n",
|
||||
"len(dataset)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"+------------------------------+-------------------------------------------------+\n",
|
||||
"| raw_words | words |\n",
|
||||
"+------------------------------+-------------------------------------------------+\n",
|
||||
"| This is the first instance . | ['This', 'is', 'the', 'first', 'instance', '.'] |\n",
|
||||
"| Second instance . | ['Second', 'instance', '.'] |\n",
|
||||
"| Third instance . | ['Third', 'instance', '.'] |\n",
|
||||
"+------------------------------+-------------------------------------------------+"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import DataSet\n",
|
||||
"data = {'raw_words':[\"This is the first instance .\", \"Second instance .\", \"Third instance .\"]}\n",
|
||||
"dataset = DataSet(data)\n",
|
||||
"\n",
|
||||
"# 将句子分成单词形式, 详见DataSet.apply()方法\n",
|
||||
"dataset.apply(lambda ins: ins['raw_words'].split(), new_field_name='words')\n",
|
||||
"\n",
|
||||
"# 或使用DataSet.apply_field()\n",
|
||||
"dataset.apply_field(lambda sent:sent.split(), field_name='raw_words', new_field_name='words')\n",
|
||||
"\n",
|
||||
"# 除了匿名函数,也可以定义函数传递进去\n",
|
||||
"def get_words(instance):\n",
|
||||
" sentence = instance['raw_words']\n",
|
||||
" words = sentence.split()\n",
|
||||
" return words\n",
|
||||
"dataset.apply(get_words, new_field_name='words')\n",
|
||||
"dataset"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python Now",
|
||||
"language": "python",
|
||||
"name": "now"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1,343 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# fastNLP中的 Vocabulary\n",
|
||||
"## 构建 Vocabulary"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.add_word_lst(['复', '旦', '大', '学']) # 加入新的字\n",
|
||||
"vocab.add_word('上海') # `上海`会作为一个整体\n",
|
||||
"vocab.to_index('复') # 应该会为3\n",
|
||||
"vocab.to_index('我') # 会输出1,Vocabulary中默认pad的index为0, unk(没有找到的词)的index为1\n",
|
||||
"\n",
|
||||
"# 在构建target的Vocabulary时,词表中应该用不上pad和unk,可以通过以下的初始化\n",
|
||||
"vocab = Vocabulary(unknown=None, padding=None)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Vocabulary(['positive', 'negative']...)"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vocab.add_word_lst(['positive', 'negative'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vocab.to_index('positive')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 没有设置 unk 的情况"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "ValueError",
|
||||
"evalue": "word `neutral` not in vocabulary",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m<ipython-input-4-c6d424040b45>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'neutral'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 会报错,因为没有unk这种情况\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||
"\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36mto_index\u001b[0;34m(self, w)\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mreturn\u001b[0m \u001b[0mint\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mnumber\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 415\u001b[0m \"\"\"\n\u001b[0;32m--> 416\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 417\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36m_wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_word2idx\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrebuild\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild_vocab\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_wrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, w)\u001b[0m\n\u001b[1;32m 272\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_word2idx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munknown\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"word `{}` not in vocabulary\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 275\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 276\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_check_build_vocab\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mValueError\u001b[0m: word `neutral` not in vocabulary"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vocab.to_index('neutral') # 会报错,因为没有unk这种情况"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 设置 unk 的情况"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(0, '<unk>')"
|
||||
]
|
||||
},
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary(unknown='<unk>', padding=None)\n",
|
||||
"vocab.add_word_lst(['positive', 'negative'])\n",
|
||||
"vocab.to_index('neutral'), vocab.to_word(vocab.to_index('neutral'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Vocabulary(['positive', 'negative']...)"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vocab"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"+---------------------------------------------------+--------+\n",
|
||||
"| chars | target |\n",
|
||||
"+---------------------------------------------------+--------+\n",
|
||||
"| [4, 2, 2, 5, 6, 7, 3] | 0 |\n",
|
||||
"| [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 3] | 1 |\n",
|
||||
"+---------------------------------------------------+--------+\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"from fastNLP import DataSet\n",
|
||||
"\n",
|
||||
"dataset = DataSet({'chars': [\n",
|
||||
" ['今', '天', '天', '气', '很', '好', '。'],\n",
|
||||
" ['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。']\n",
|
||||
" ],\n",
|
||||
" 'target': ['neutral', 'negative']\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.from_dataset(dataset, field_name='chars')\n",
|
||||
"vocab.index_dataset(dataset, field_name='chars')\n",
|
||||
"\n",
|
||||
"target_vocab = Vocabulary(padding=None, unknown=None)\n",
|
||||
"target_vocab.from_dataset(dataset, field_name='target')\n",
|
||||
"target_vocab.index_dataset(dataset, field_name='target')\n",
|
||||
"print(dataset)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Vocabulary(['今', '天', '心', '情', '很']...)"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"from fastNLP import DataSet\n",
|
||||
"\n",
|
||||
"tr_data = DataSet({'chars': [\n",
|
||||
" ['今', '天', '心', '情', '很', '好', '。'],\n",
|
||||
" ['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。']\n",
|
||||
" ],\n",
|
||||
" 'target': ['positive', 'negative']\n",
|
||||
"})\n",
|
||||
"dev_data = DataSet({'chars': [\n",
|
||||
" ['住', '宿', '条', '件', '还', '不', '错'],\n",
|
||||
" ['糟', '糕', '的', '天', '气', ',', '无', '法', '出', '行', '。']\n",
|
||||
" ],\n",
|
||||
" 'target': ['positive', 'negative']\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"# 将验证集或者测试集在建立词表是放入no_create_entry_dataset这个参数中。\n",
|
||||
"vocab.from_dataset(tr_data, field_name='chars', no_create_entry_dataset=[dev_data])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 4%|▎ | 2.31M/63.5M [00:00<00:02, 22.9MB/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"http://212.129.155.247/embedding/glove.6B.50d.zip not found in cache, downloading to /tmp/tmpvziobj_e\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 63.5M/63.5M [00:01<00:00, 41.3MB/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Finish download from http://212.129.155.247/embedding/glove.6B.50d.zip\n",
|
||||
"Copy file to /remote-home/ynzheng/.fastNLP/embedding/glove.6B.50d\n",
|
||||
"Found 2 out of 6 words in the pre-training embedding.\n",
|
||||
"tensor([[ 0.9497, 0.3433, 0.8450, -0.8852, -0.7208, -0.2931, -0.7468, 0.6512,\n",
|
||||
" 0.4730, -0.7401, 0.1877, -0.3828, -0.5590, 0.4295, -0.2698, -0.4238,\n",
|
||||
" -0.3124, 1.3423, -0.7857, -0.6302, 0.9182, 0.2113, -0.5744, 1.4549,\n",
|
||||
" 0.7546, -1.6165, -0.0085, 0.0029, 0.5130, -0.4745, 2.5306, 0.8594,\n",
|
||||
" -0.3067, 0.0578, 0.6623, 0.2080, 0.6424, -0.5246, -0.0534, 1.1404,\n",
|
||||
" -0.1370, -0.1836, 0.4546, -0.5096, -0.0255, -0.0286, 0.1805, -0.4483,\n",
|
||||
" 0.4053, -0.3682]], grad_fn=<EmbeddingBackward>)\n",
|
||||
"tensor([[ 0.1320, -0.2392, 0.1732, -0.2390, -0.0463, 0.0494, 0.0488, -0.0886,\n",
|
||||
" 0.0224, -0.1300, 0.0369, 0.1800, 0.0750, -0.0183, 0.2264, 0.1628,\n",
|
||||
" 0.1261, -0.1259, 0.1663, -0.1230, -0.1904, -0.0532, 0.1397, -0.0259,\n",
|
||||
" -0.1799, 0.0226, 0.1858, 0.1981, 0.1338, 0.2394, 0.0248, 0.0203,\n",
|
||||
" -0.1722, -0.1683, -0.1892, 0.0874, 0.0562, -0.0394, 0.0306, -0.1761,\n",
|
||||
" 0.1015, -0.0171, 0.1172, 0.1357, 0.1519, -0.0011, 0.1572, 0.1265,\n",
|
||||
" -0.2391, -0.0258]], grad_fn=<EmbeddingBackward>)\n",
|
||||
"tensor([[ 0.1318, -0.2552, -0.0679, 0.2619, -0.2616, 0.2357, 0.1308, -0.0118,\n",
|
||||
" 1.7659, 0.2078, 0.2620, -0.1643, -0.8464, 0.0201, 0.0702, 0.3978,\n",
|
||||
" 0.1528, -0.2021, -1.6184, -0.5433, -0.1786, 0.5389, 0.4987, -0.1017,\n",
|
||||
" 0.6626, -1.7051, 0.0572, -0.3241, -0.6683, 0.2665, 2.8420, 0.2684,\n",
|
||||
" -0.5954, -0.5004, 1.5199, 0.0396, 1.6659, 0.9976, -0.5597, -0.7049,\n",
|
||||
" -0.0309, -0.2830, -0.1356, 0.6429, 0.4149, 1.2362, 0.7659, 0.9780,\n",
|
||||
" 0.5851, -0.3018]], grad_fn=<EmbeddingBackward>)\n",
|
||||
"tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||
" 0., 0.]], grad_fn=<EmbeddingBackward>)\n",
|
||||
"tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
|
||||
" 0., 0.]], grad_fn=<EmbeddingBackward>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from fastNLP.embeddings import StaticEmbedding\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.add_word('train')\n",
|
||||
"vocab.add_word('only_in_train') # 仅在train出现,但肯定在预训练词表中不存在\n",
|
||||
"vocab.add_word('test', no_create_entry=True) # 该词只在dev或test中出现\n",
|
||||
"vocab.add_word('only_in_test', no_create_entry=True) # 这个词在预训练的词表中找不到\n",
|
||||
"\n",
|
||||
"embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('train')])))\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('only_in_train')])))\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('test')])))\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('only_in_test')])))\n",
|
||||
"print(embed(torch.LongTensor([vocab.unknown_idx])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python Now",
|
||||
"language": "python",
|
||||
"name": "now"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1,524 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found 5 out of 7 words in the pre-training embedding.\n",
|
||||
"torch.Size([1, 5, 50])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from fastNLP.embeddings import StaticEmbedding\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.add_word_lst(\"this is a demo .\".split())\n",
|
||||
"\n",
|
||||
"embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n",
|
||||
"\n",
|
||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]]) # 将文本转为index\n",
|
||||
"print(embed(words).size()) # StaticEmbedding的使用和pytorch的nn.Embedding是类似的"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([1, 5, 30])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import StaticEmbedding\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.add_word_lst(\"this is a demo .\".split())\n",
|
||||
"\n",
|
||||
"embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=30)\n",
|
||||
"\n",
|
||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n",
|
||||
"print(embed(words).size())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"22 out of 22 characters were found in pretrained elmo embedding.\n",
|
||||
"torch.Size([1, 5, 256])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import ElmoEmbedding\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.add_word_lst(\"this is a demo .\".split())\n",
|
||||
"\n",
|
||||
"embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=False)\n",
|
||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n",
|
||||
"print(embed(words).size())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"22 out of 22 characters were found in pretrained elmo embedding.\n",
|
||||
"torch.Size([1, 5, 512])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=False, layers='1,2')\n",
|
||||
"print(embed(words).size())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"22 out of 22 characters were found in pretrained elmo embedding.\n",
|
||||
"torch.Size([1, 5, 256])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=True, layers='mix')\n",
|
||||
"print(embed(words).size()) # 三层输出按照权重element-wise的加起来"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
|
||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
|
||||
"Start to generate word pieces for word.\n",
|
||||
"Found(Or segment into word pieces) 7 words out of 7.\n",
|
||||
"torch.Size([1, 5, 768])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import BertEmbedding\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.add_word_lst(\"this is a demo .\".split())\n",
|
||||
"\n",
|
||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased')\n",
|
||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n",
|
||||
"print(embed(words).size())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
|
||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
|
||||
"Start to generate word pieces for word.\n",
|
||||
"Found(Or segment into word pieces) 7 words out of 7.\n",
|
||||
"torch.Size([1, 5, 1536])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 使用后面两层的输出\n",
|
||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='10,11')\n",
|
||||
"print(embed(words).size()) # 结果将是在最后一维做拼接"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
|
||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
|
||||
"Start to generate word pieces for word.\n",
|
||||
"Found(Or segment into word pieces) 7 words out of 7.\n",
|
||||
"torch.Size([1, 7, 768])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', include_cls_sep=True)\n",
|
||||
"print(embed(words).size()) # 结果将在序列维度上增加2\n",
|
||||
"# 取出句子的cls表示\n",
|
||||
"cls_reps = embed(words)[:, 0] # shape: [batch_size, 768]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
|
||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
|
||||
"Start to generate word pieces for word.\n",
|
||||
"Found(Or segment into word pieces) 7 words out of 7.\n",
|
||||
"torch.Size([1, 5, 768])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', pool_method='max')\n",
|
||||
"print(embed(words).size())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
|
||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
|
||||
"Start to generate word pieces for word.\n",
|
||||
"Found(Or segment into word pieces) 10 words out of 10.\n",
|
||||
"torch.Size([1, 9, 768])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.add_word_lst(\"this is a demo . [SEP] another sentence .\".split())\n",
|
||||
"\n",
|
||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', pool_method='max')\n",
|
||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo . [SEP] another sentence .\".split()]])\n",
|
||||
"print(embed(words).size())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Start constructing character vocabulary.\n",
|
||||
"In total, there are 8 distinct characters.\n",
|
||||
"torch.Size([1, 5, 64])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import CNNCharEmbedding\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.add_word_lst(\"this is a demo .\".split())\n",
|
||||
"\n",
|
||||
"# character的embedding维度大小为50,返回的embedding结果维度大小为64。\n",
|
||||
"embed = CNNCharEmbedding(vocab, embed_size=64, char_emb_size=50)\n",
|
||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n",
|
||||
"print(embed(words).size())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Start constructing character vocabulary.\n",
|
||||
"In total, there are 8 distinct characters.\n",
|
||||
"torch.Size([1, 5, 64])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import LSTMCharEmbedding\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.add_word_lst(\"this is a demo .\".split())\n",
|
||||
"\n",
|
||||
"# character的embedding维度大小为50,返回的embedding结果维度大小为64。\n",
|
||||
"embed = LSTMCharEmbedding(vocab, embed_size=64, char_emb_size=50)\n",
|
||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n",
|
||||
"print(embed(words).size())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found 5 out of 7 words in the pre-training embedding.\n",
|
||||
"50\n",
|
||||
"Start constructing character vocabulary.\n",
|
||||
"In total, there are 8 distinct characters.\n",
|
||||
"30\n",
|
||||
"22 out of 22 characters were found in pretrained elmo embedding.\n",
|
||||
"256\n",
|
||||
"22 out of 22 characters were found in pretrained elmo embedding.\n",
|
||||
"512\n",
|
||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
|
||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
|
||||
"Start to generate word pieces for word.\n",
|
||||
"Found(Or segment into word pieces) 7 words out of 7.\n",
|
||||
"768\n",
|
||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
|
||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
|
||||
"Start to generate word pieces for word.\n",
|
||||
"Found(Or segment into word pieces) 7 words out of 7.\n",
|
||||
"1536\n",
|
||||
"80\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import *\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.add_word_lst(\"this is a demo .\".split())\n",
|
||||
"\n",
|
||||
"static_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n",
|
||||
"print(static_embed.embedding_dim) # 50\n",
|
||||
"char_embed = CNNCharEmbedding(vocab, embed_size=30)\n",
|
||||
"print(char_embed.embedding_dim) # 30\n",
|
||||
"elmo_embed_1 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='2')\n",
|
||||
"print(elmo_embed_1.embedding_dim) # 256\n",
|
||||
"elmo_embed_2 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='1,2')\n",
|
||||
"print(elmo_embed_2.embedding_dim) # 512\n",
|
||||
"bert_embed_1 = BertEmbedding(vocab, layers='-1', model_dir_or_name='en-base-cased')\n",
|
||||
"print(bert_embed_1.embedding_dim) # 768\n",
|
||||
"bert_embed_2 = BertEmbedding(vocab, layers='2,-1', model_dir_or_name='en-base-cased')\n",
|
||||
"print(bert_embed_2.embedding_dim) # 1536\n",
|
||||
"stack_embed = StackEmbedding([static_embed, char_embed])\n",
|
||||
"print(stack_embed.embedding_dim) # 80"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
|
||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
|
||||
"Start to generate word pieces for word.\n",
|
||||
"Found(Or segment into word pieces) 7 words out of 7.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import *\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"vocab.add_word_lst(\"this is a demo .\".split())\n",
|
||||
"\n",
|
||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', requires_grad=True) # 初始化时设定为需要更新\n",
|
||||
"embed.requires_grad = False # 修改BertEmbedding的权重为不更新"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[ 0.3633, -0.2091, -0.0353, -0.3771, -0.5193]],\n",
|
||||
" grad_fn=<EmbeddingBackward>)\n",
|
||||
"tensor([[ 0.0926, -0.4812, -0.7744, 0.4836, -0.5475]],\n",
|
||||
" grad_fn=<EmbeddingBackward>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import StaticEmbedding\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary().add_word_lst(\"The the a A\".split())\n",
|
||||
"# 下面用随机的StaticEmbedding演示,但与使用预训练词向量时效果是一致的\n",
|
||||
"embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5)\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('The')])))\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('the')])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"All word in the vocab have been lowered. There are 6 words, 4 unique lowered words.\n",
|
||||
"tensor([[ 0.4530, -0.1558, -0.1941, 0.3203, 0.0355]],\n",
|
||||
" grad_fn=<EmbeddingBackward>)\n",
|
||||
"tensor([[ 0.4530, -0.1558, -0.1941, 0.3203, 0.0355]],\n",
|
||||
" grad_fn=<EmbeddingBackward>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import StaticEmbedding\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary().add_word_lst(\"The the a A\".split())\n",
|
||||
"# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的\n",
|
||||
"embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, lower=True)\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('The')])))\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('the')])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1 out of 4 words have frequency less than 2.\n",
|
||||
"tensor([[ 0.4724, -0.7277, -0.6350, -0.5258, -0.6063]],\n",
|
||||
" grad_fn=<EmbeddingBackward>)\n",
|
||||
"tensor([[ 0.7638, -0.0552, 0.1625, -0.2210, 0.4993]],\n",
|
||||
" grad_fn=<EmbeddingBackward>)\n",
|
||||
"tensor([[ 0.7638, -0.0552, 0.1625, -0.2210, 0.4993]],\n",
|
||||
" grad_fn=<EmbeddingBackward>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import StaticEmbedding\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary().add_word_lst(\"the the the a\".split())\n",
|
||||
"# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的\n",
|
||||
"embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, min_freq=2)\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('the')])))\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('a')])))\n",
|
||||
"print(embed(torch.LongTensor([vocab.unknown_idx])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0 out of 5 words have frequency less than 2.\n",
|
||||
"All word in the vocab have been lowered. There are 5 words, 4 unique lowered words.\n",
|
||||
"tensor([[ 0.1943, 0.3739, 0.2769, -0.4746, -0.3181]],\n",
|
||||
" grad_fn=<EmbeddingBackward>)\n",
|
||||
"tensor([[ 0.5892, -0.6916, 0.7319, -0.3803, 0.4979]],\n",
|
||||
" grad_fn=<EmbeddingBackward>)\n",
|
||||
"tensor([[ 0.5892, -0.6916, 0.7319, -0.3803, 0.4979]],\n",
|
||||
" grad_fn=<EmbeddingBackward>)\n",
|
||||
"tensor([[-0.1348, -0.2172, -0.0071, 0.5704, -0.2607]],\n",
|
||||
" grad_fn=<EmbeddingBackward>)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import StaticEmbedding\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary().add_word_lst(\"the the the a A\".split())\n",
|
||||
"# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的\n",
|
||||
"embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, min_freq=2, lower=True)\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('the')])))\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('a')])))\n",
|
||||
"print(embed(torch.LongTensor([vocab.to_index('A')])))\n",
|
||||
"print(embed(torch.LongTensor([vocab.unknown_idx])))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python Now",
|
||||
"language": "python",
|
||||
"name": "now"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1,309 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 使用Loader和Pipe加载并处理数据集\n",
|
||||
"\n",
|
||||
"这一部分是关于如何加载数据集的教程\n",
|
||||
"\n",
|
||||
"## Part I: 数据集容器DataBundle\n",
|
||||
"\n",
|
||||
"而由于对于同一个任务,训练集,验证集和测试集会共用同一个词表以及具有相同的目标值,所以在fastNLP中我们使用了 DataBundle 来承载同一个任务的多个数据集 DataSet 以及它们的词表 Vocabulary 。下面会有例子介绍 DataBundle 的相关使用。\n",
|
||||
"\n",
|
||||
"DataBundle 在fastNLP中主要在各个 Loader 和 Pipe 中被使用。 下面我们先介绍一下 Loader 和 Pipe 。\n",
|
||||
"\n",
|
||||
"## Part II: 加载的各种数据集的Loader\n",
|
||||
"\n",
|
||||
"在fastNLP中,所有的 Loader 都可以通过其文档判断其支持读取的数据格式,以及读取之后返回的 DataSet 的格式, 例如 ChnSentiCorpLoader \n",
|
||||
"\n",
|
||||
"- download() 函数:自动将该数据集下载到缓存地址,默认缓存地址为~/.fastNLP/datasets/。由于版权等原因,不是所有的Loader都实现了该方法。该方法会返回下载后文件所处的缓存地址。\n",
|
||||
"\n",
|
||||
"- _load() 函数:从一个数据文件中读取数据,返回一个 DataSet 。返回的DataSet的格式可从Loader文档判断。\n",
|
||||
"\n",
|
||||
"- load() 函数:从文件或者文件夹中读取数据为 DataSet 并将它们组装成 DataBundle。支持接受的参数类型有以下的几种\n",
|
||||
"\n",
|
||||
" - None, 将尝试读取自动缓存的数据,仅支持提供了自动下载数据的Loader\n",
|
||||
" - 文件夹路径, 默认将尝试在该文件夹下匹配文件名中含有 train , test , dev 的文件,如果有多个文件含有相同的关键字,将无法通过该方式读取\n",
|
||||
" - dict, 例如{'train':\"/path/to/tr.conll\", 'dev':\"/to/validate.conll\", \"test\":\"/to/te.conll\"}。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"In total 3 datasets:\n",
|
||||
"\ttest has 1944 instances.\n",
|
||||
"\ttrain has 17196 instances.\n",
|
||||
"\tdev has 1858 instances.\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.io import CWSLoader\n",
|
||||
"\n",
|
||||
"loader = CWSLoader(dataset_name='pku')\n",
|
||||
"data_bundle = loader.load()\n",
|
||||
"print(data_bundle)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"这里表示一共有3个数据集。其中:\n",
|
||||
"\n",
|
||||
" 3个数据集的名称分别为train、dev、test,分别有17223、1831、1944个instance\n",
|
||||
"\n",
|
||||
"也可以取出DataSet,并打印DataSet中的具体内容"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"+----------------------------------------------------------------+\n",
|
||||
"| raw_words |\n",
|
||||
"+----------------------------------------------------------------+\n",
|
||||
"| 迈向 充满 希望 的 新 世纪 —— 一九九八年 新年 讲话 ... |\n",
|
||||
"| 中共中央 总书记 、 国家 主席 江 泽民 |\n",
|
||||
"+----------------------------------------------------------------+\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tr_data = data_bundle.get_dataset('train')\n",
|
||||
"print(tr_data[:2])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Part III: 使用Pipe对数据集进行预处理\n",
|
||||
"\n",
|
||||
"通过 Loader 可以将文本数据读入,但并不能直接被神经网络使用,还需要进行一定的预处理。\n",
|
||||
"\n",
|
||||
"在fastNLP中,我们使用 Pipe 的子类作为数据预处理的类, Loader 和 Pipe 一般具备一一对应的关系,该关系可以从其名称判断, 例如 CWSLoader 与 CWSPipe 是一一对应的。一般情况下Pipe处理包含以下的几个过程,\n",
|
||||
"1. 将raw_words或 raw_chars进行tokenize以切分成不同的词或字; \n",
|
||||
"2. 再建立词或字的 Vocabulary , 并将词或字转换为index; \n",
|
||||
"3. 将target 列建立词表并将target列转为index;\n",
|
||||
"\n",
|
||||
"所有的Pipe都可通过其文档查看该Pipe支持处理的 DataSet 以及返回的 DataBundle 中的Vocabulary的情况; 如 OntoNotesNERPipe\n",
|
||||
"\n",
|
||||
"各种数据集的Pipe当中,都包含了以下的两个函数:\n",
|
||||
"\n",
|
||||
"- process() 函数:对输入的 DataBundle 进行处理, 然后返回处理之后的 DataBundle 。process函数的文档中包含了该Pipe支持处理的DataSet的格式。\n",
|
||||
"- process_from_file() 函数:输入数据集所在文件夹,使用对应的Loader读取数据(所以该函数支持的参数类型是由于其对应的Loader的load函数决定的),然后调用相对应的process函数对数据进行预处理。相当于是把Load和process放在一个函数中执行。\n",
|
||||
"\n",
|
||||
"接着上面 CWSLoader 的例子,我们展示一下 CWSPipe 的功能:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"In total 3 datasets:\n",
|
||||
"\ttest has 1944 instances.\n",
|
||||
"\ttrain has 17196 instances.\n",
|
||||
"\tdev has 1858 instances.\n",
|
||||
"In total 2 vocabs:\n",
|
||||
"\tchars has 4777 entries.\n",
|
||||
"\ttarget has 4 entries.\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.io import CWSPipe\n",
|
||||
"\n",
|
||||
"data_bundle = CWSPipe().process(data_bundle)\n",
|
||||
"print(data_bundle)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"表示一共有3个数据集和2个词表。其中:\n",
|
||||
"\n",
|
||||
"- 3个数据集的名称分别为train、dev、test,分别有17223、1831、1944个instance\n",
|
||||
"- 2个词表分别为chars词表与target词表。其中chars词表为句子文本所构建的词表,一共有4777个不同的字;target词表为目标标签所构建的词表,一共有4种标签。\n",
|
||||
"\n",
|
||||
"相较于之前CWSLoader读取的DataBundle,新增了两个Vocabulary。 我们可以打印一下处理之后的DataSet"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"+---------------------+---------------------+---------------------+---------+\n",
|
||||
"| raw_words | chars | target | seq_len |\n",
|
||||
"+---------------------+---------------------+---------------------+---------+\n",
|
||||
"| 迈向 充满 希望... | [1224, 178, 674,... | [0, 1, 0, 1, 0, ... | 29 |\n",
|
||||
"| 中共中央 总书记... | [11, 212, 11, 33... | [0, 3, 3, 1, 0, ... | 15 |\n",
|
||||
"+---------------------+---------------------+---------------------+---------+\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tr_data = data_bundle.get_dataset('train')\n",
|
||||
"print(tr_data[:2])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"可以看到有两列为int的field: chars和target。这两列的名称同时也是DataBundle中的Vocabulary的名称。可以通过下列的代码获取并查看Vocabulary的 信息"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Vocabulary(['B', 'E', 'S', 'M']...)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"vocab = data_bundle.get_vocab('target')\n",
|
||||
"print(vocab)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Part IV: fastNLP封装好的Loader和Pipe\n",
|
||||
"\n",
|
||||
"fastNLP封装了多种任务/数据集的 Loader 和 Pipe 并提供自动下载功能,具体参见文档 [数据集](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0)\n",
|
||||
"\n",
|
||||
"## Part V: 不同格式类型的基础Loader\n",
|
||||
"\n",
|
||||
"除了上面提到的针对具体任务的Loader,我们还提供了CSV格式和JSON格式的Loader\n",
|
||||
"\n",
|
||||
"**CSVLoader** 读取CSV类型的数据集文件。例子如下:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"from fastNLP.io.loader import CSVLoader\n",
|
||||
"data_set_loader = CSVLoader(\n",
|
||||
" headers=('raw_words', 'target'), sep='\\t'\n",
|
||||
")\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"表示将CSV文件中每一行的第一项将填入'raw_words' field,第二项填入'target' field。其中项之间由'\\t'分割开来\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"data_set = data_set_loader._load('path/to/your/file')\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"文件内容样例如下\n",
|
||||
"\n",
|
||||
"```csv\n",
|
||||
"But it does not leave you with much . 1\n",
|
||||
"You could hate it for the same reason . 1\n",
|
||||
"The performances are an absolute joy . 4\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"读取之后的DataSet具有以下的field\n",
|
||||
"\n",
|
||||
"| raw_words | target |\n",
|
||||
"| --------------------------------------- | ------ |\n",
|
||||
"| But it does not leave you with much . | 1 |\n",
|
||||
"| You could hate it for the same reason . | 1 |\n",
|
||||
"| The performances are an absolute joy . | 4 |\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**JsonLoader** 读取Json类型的数据集文件,数据必须按行存储,每行是一个包含各类属性的Json对象。例子如下\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"from fastNLP.io.loader import JsonLoader\n",
|
||||
"loader = JsonLoader(\n",
|
||||
" fields={'sentence1': 'raw_words1', 'sentence2': 'raw_words2', 'gold_label': 'target'}\n",
|
||||
")\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"表示将Json对象中'sentence1'、'sentence2'和'gold_label'对应的值赋给'raw_words1'、'raw_words2'、'target'这三个fields\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"data_set = loader._load('path/to/your/file')\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"数据集内容样例如下\n",
|
||||
"```\n",
|
||||
"{\"annotator_labels\": [\"neutral\"], \"captionID\": \"3416050480.jpg#4\", \"gold_label\": \"neutral\", ... }\n",
|
||||
"{\"annotator_labels\": [\"contradiction\"], \"captionID\": \"3416050480.jpg#4\", \"gold_label\": \"contradiction\", ... }\n",
|
||||
"{\"annotator_labels\": [\"entailment\"], \"captionID\": \"3416050480.jpg#4\", \"gold_label\": \"entailment\", ... }\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"读取之后的DataSet具有以下的field\n",
|
||||
"\n",
|
||||
"| raw_words0 | raw_words1 | target |\n",
|
||||
"| ------------------------------------------------------ | ------------------------------------------------- | ------------- |\n",
|
||||
"| A person on a horse jumps over a broken down airplane. | A person is training his horse for a competition. | neutral |\n",
|
||||
"| A person on a horse jumps over a broken down airplane. | A person is at a diner, ordering an omelette. | contradiction |\n",
|
||||
"| A person on a horse jumps over a broken down airplane. | A person is outdoors, on a horse. | entailment |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python Now",
|
||||
"language": "python",
|
||||
"name": "now"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1,603 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 使用Trainer和Tester快速训练和测试"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 数据读入和处理"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"In total 3 datasets:\n",
|
||||
"\ttest has 1821 instances.\n",
|
||||
"\ttrain has 67349 instances.\n",
|
||||
"\tdev has 872 instances.\n",
|
||||
"In total 2 vocabs:\n",
|
||||
"\twords has 16292 entries.\n",
|
||||
"\ttarget has 2 entries.\n",
|
||||
"\n",
|
||||
"+-----------------------------------+--------+-----------------------------------+---------+\n",
|
||||
"| raw_words | target | words | seq_len |\n",
|
||||
"+-----------------------------------+--------+-----------------------------------+---------+\n",
|
||||
"| hide new secretions from the p... | 1 | [4110, 97, 12009, 39, 2, 6843,... | 7 |\n",
|
||||
"+-----------------------------------+--------+-----------------------------------+---------+\n",
|
||||
"Vocabulary(['hide', 'new', 'secretions', 'from', 'the']...)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.io import SST2Pipe\n",
|
||||
"\n",
|
||||
"pipe = SST2Pipe()\n",
|
||||
"databundle = pipe.process_from_file()\n",
|
||||
"vocab = databundle.get_vocab('words')\n",
|
||||
"print(databundle)\n",
|
||||
"print(databundle.get_dataset('train')[0])\n",
|
||||
"print(databundle.get_vocab('words'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"4925 872 75\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"train_data = databundle.get_dataset('train')[:5000]\n",
|
||||
"train_data, test_data = train_data.split(0.015)\n",
|
||||
"dev_data = databundle.get_dataset('dev')\n",
|
||||
"print(len(train_data),len(dev_data),len(test_data))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"+-------------+-----------+--------+-------+---------+\n",
|
||||
"| field_names | raw_words | target | words | seq_len |\n",
|
||||
"+-------------+-----------+--------+-------+---------+\n",
|
||||
"| is_input | False | False | True | True |\n",
|
||||
"| is_target | False | True | False | False |\n",
|
||||
"| ignore_type | | False | False | False |\n",
|
||||
"| pad_value | | 0 | 0 | 0 |\n",
|
||||
"+-------------+-----------+--------+-------+---------+\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<prettytable.PrettyTable at 0x7f49ec540160>"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"train_data.print_field_meta()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 使用内置模型训练"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP.models import CNNText\n",
|
||||
"\n",
|
||||
"#词嵌入的维度\n",
|
||||
"EMBED_DIM = 100\n",
|
||||
"\n",
|
||||
"#使用CNNText的时候第一个参数输入一个tuple,作为模型定义embedding的参数\n",
|
||||
"#还可以传入 kernel_nums, kernel_sizes, padding, dropout的自定义值\n",
|
||||
"model_cnn = CNNText((len(vocab),EMBED_DIM), num_classes=2, dropout=0.1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import AccuracyMetric\n",
|
||||
"from fastNLP import Const\n",
|
||||
"\n",
|
||||
"# metrics=AccuracyMetric() 在本例中与下面这行代码等价\n",
|
||||
"metrics=AccuracyMetric(pred=Const.OUTPUT, target=Const.TARGET)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import CrossEntropyLoss\n",
|
||||
"\n",
|
||||
"# loss = CrossEntropyLoss() 在本例中与下面这行代码等价\n",
|
||||
"loss = CrossEntropyLoss(pred=Const.OUTPUT, target=Const.TARGET)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 这表示构建了一个损失函数类,由func计算损失函数,其中将从模型返回值或者DataSet的target=True的field\n",
|
||||
"# 当中找到一个参数名为`pred`的参数传入func一个参数名为`input`的参数;找到一个参数名为`label`的参数\n",
|
||||
"# 传入func作为一个名为`target`的参数\n",
|
||||
"#下面自己构建了一个交叉熵函数,和之后直接使用fastNLP中的交叉熵函数是一个效果\n",
|
||||
"import torch\n",
|
||||
"from fastNLP import LossFunc\n",
|
||||
"func = torch.nn.functional.cross_entropy\n",
|
||||
"loss_func = LossFunc(func, input=Const.OUTPUT, target=Const.TARGET)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch.optim as optim\n",
|
||||
"\n",
|
||||
"#使用 torch.optim 定义优化器\n",
|
||||
"optimizer=optim.RMSprop(model_cnn.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"input fields after batch(if batch size is 2):\n",
|
||||
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n",
|
||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
|
||||
"target fields after batch(if batch size is 2):\n",
|
||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
|
||||
"\n",
|
||||
"training epochs started 2020-02-27-11-31-25\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=3080.0), HTML(value='')), layout=Layout(d…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.75 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 1/10. Step:308/3080: \n",
|
||||
"\r",
|
||||
"AccuracyMetric: acc=0.751147\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.83 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 2/10. Step:616/3080: \n",
|
||||
"\r",
|
||||
"AccuracyMetric: acc=0.755734\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 1.32 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 3/10. Step:924/3080: \n",
|
||||
"\r",
|
||||
"AccuracyMetric: acc=0.758028\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.88 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 4/10. Step:1232/3080: \n",
|
||||
"\r",
|
||||
"AccuracyMetric: acc=0.741972\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.96 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 5/10. Step:1540/3080: \n",
|
||||
"\r",
|
||||
"AccuracyMetric: acc=0.728211\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.87 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 6/10. Step:1848/3080: \n",
|
||||
"\r",
|
||||
"AccuracyMetric: acc=0.755734\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 1.04 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 7/10. Step:2156/3080: \n",
|
||||
"\r",
|
||||
"AccuracyMetric: acc=0.732798\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.57 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 8/10. Step:2464/3080: \n",
|
||||
"\r",
|
||||
"AccuracyMetric: acc=0.747706\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.48 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 9/10. Step:2772/3080: \n",
|
||||
"\r",
|
||||
"AccuracyMetric: acc=0.732798\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.48 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 10/10. Step:3080/3080: \n",
|
||||
"\r",
|
||||
"AccuracyMetric: acc=0.740826\n",
|
||||
"\n",
|
||||
"\r\n",
|
||||
"In Epoch:3/Step:924, got best dev performance:\n",
|
||||
"AccuracyMetric: acc=0.758028\n",
|
||||
"Reloaded the best model.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'best_eval': {'AccuracyMetric': {'acc': 0.758028}},\n",
|
||||
" 'best_epoch': 3,\n",
|
||||
" 'best_step': 924,\n",
|
||||
" 'seconds': 160.58}"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import Trainer\n",
|
||||
"\n",
|
||||
"#训练的轮数和batch size\n",
|
||||
"N_EPOCHS = 10\n",
|
||||
"BATCH_SIZE = 16\n",
|
||||
"\n",
|
||||
"#如果在定义trainer的时候没有传入optimizer参数,模型默认的优化器为torch.optim.Adam且learning rate为lr=4e-3\n",
|
||||
"#这里只使用了loss作为损失函数输入,感兴趣可以尝试其他损失函数(如之前自定义的loss_func)作为输入\n",
|
||||
"trainer = Trainer(model=model_cnn, train_data=train_data, dev_data=dev_data, loss=loss, metrics=metrics,\n",
|
||||
"optimizer=optimizer,n_epochs=N_EPOCHS, batch_size=BATCH_SIZE)\n",
|
||||
"trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.43 seconds!\n",
|
||||
"[tester] \n",
|
||||
"AccuracyMetric: acc=0.773333\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'AccuracyMetric': {'acc': 0.773333}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import Tester\n",
|
||||
"\n",
|
||||
"tester = Tester(test_data, model_cnn, metrics=AccuracyMetric())\n",
|
||||
"tester.test()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python Now",
|
||||
"language": "python",
|
||||
"name": "now"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1,681 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 使用Trainer和Tester快速训练和测试"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 数据读入和处理"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"In total 3 datasets:\n",
|
||||
"\ttest has 1821 instances.\n",
|
||||
"\ttrain has 67349 instances.\n",
|
||||
"\tdev has 872 instances.\n",
|
||||
"In total 2 vocabs:\n",
|
||||
"\twords has 16292 entries.\n",
|
||||
"\ttarget has 2 entries.\n",
|
||||
"\n",
|
||||
"+-----------------------------------+--------+-----------------------------------+---------+\n",
|
||||
"| raw_words | target | words | seq_len |\n",
|
||||
"+-----------------------------------+--------+-----------------------------------+---------+\n",
|
||||
"| hide new secretions from the p... | 1 | [4110, 97, 12009, 39, 2, 6843,... | 7 |\n",
|
||||
"+-----------------------------------+--------+-----------------------------------+---------+\n",
|
||||
"Vocabulary(['hide', 'new', 'secretions', 'from', 'the']...)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.io import SST2Pipe\n",
|
||||
"\n",
|
||||
"pipe = SST2Pipe()\n",
|
||||
"databundle = pipe.process_from_file()\n",
|
||||
"vocab = databundle.get_vocab('words')\n",
|
||||
"print(databundle)\n",
|
||||
"print(databundle.get_dataset('train')[0])\n",
|
||||
"print(databundle.get_vocab('words'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"4925 872 75\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"train_data = databundle.get_dataset('train')[:5000]\n",
|
||||
"train_data, test_data = train_data.split(0.015)\n",
|
||||
"dev_data = databundle.get_dataset('dev')\n",
|
||||
"print(len(train_data),len(dev_data),len(test_data))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"+-------------+-----------+--------+-------+---------+\n",
|
||||
"| field_names | raw_words | target | words | seq_len |\n",
|
||||
"+-------------+-----------+--------+-------+---------+\n",
|
||||
"| is_input | False | False | True | True |\n",
|
||||
"| is_target | False | True | False | False |\n",
|
||||
"| ignore_type | | False | False | False |\n",
|
||||
"| pad_value | | 0 | 0 | 0 |\n",
|
||||
"+-------------+-----------+--------+-------+---------+\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<prettytable.PrettyTable at 0x7f0db03d0640>"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"train_data.print_field_meta()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import AccuracyMetric\n",
|
||||
"from fastNLP import Const\n",
|
||||
"\n",
|
||||
"# metrics=AccuracyMetric() 在本例中与下面这行代码等价\n",
|
||||
"metrics=AccuracyMetric(pred=Const.OUTPUT, target=Const.TARGET)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## DataSetIter初探"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n",
|
||||
" 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n",
|
||||
" 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n",
|
||||
" 1323, 4398, 7],\n",
|
||||
" [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n",
|
||||
" 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n",
|
||||
" 7, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0]]), 'seq_len': tensor([33, 21])}\n",
|
||||
"batch_y: {'target': tensor([1, 0])}\n",
|
||||
"batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7],\n",
|
||||
" [ 14, 10, 437, 32, 78, 3, 78, 437, 7]]), 'seq_len': tensor([9, 9])}\n",
|
||||
"batch_y: {'target': tensor([0, 1])}\n",
|
||||
"batch_x: {'words': tensor([[ 4, 277, 685, 18, 7],\n",
|
||||
" [15618, 3204, 5, 1675, 0]]), 'seq_len': tensor([5, 4])}\n",
|
||||
"batch_y: {'target': tensor([1, 1])}\n",
|
||||
"batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n",
|
||||
" 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7],\n",
|
||||
" [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n",
|
||||
" 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7]]), 'seq_len': tensor([20, 20])}\n",
|
||||
"batch_y: {'target': tensor([0, 0])}\n",
|
||||
"batch_x: {'words': tensor([[ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n",
|
||||
" 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7],\n",
|
||||
" [ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n",
|
||||
" 1217, 7, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([20, 12])}\n",
|
||||
"batch_y: {'target': tensor([0, 1])}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import BucketSampler\n",
|
||||
"from fastNLP import DataSetIter\n",
|
||||
"\n",
|
||||
"tmp_data = dev_data[:10]\n",
|
||||
"# 定义一个Batch,传入DataSet,规定batch_size和去batch的规则。\n",
|
||||
"# 顺序(Sequential),随机(Random),相似长度组成一个batch(Bucket)\n",
|
||||
"sampler = BucketSampler(batch_size=2, seq_len_field_name='seq_len')\n",
|
||||
"batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n",
|
||||
"for batch_x, batch_y in batch:\n",
|
||||
" print(\"batch_x: \",batch_x)\n",
|
||||
" print(\"batch_y: \", batch_y)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n",
|
||||
" 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n",
|
||||
" 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n",
|
||||
" 1323, 4398, 7],\n",
|
||||
" [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n",
|
||||
" 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n",
|
||||
" 7, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
|
||||
" -1, -1, -1]]), 'seq_len': tensor([33, 21])}\n",
|
||||
"batch_y: {'target': tensor([1, 0])}\n",
|
||||
"batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7],\n",
|
||||
" [ 14, 10, 437, 32, 78, 3, 78, 437, 7]]), 'seq_len': tensor([9, 9])}\n",
|
||||
"batch_y: {'target': tensor([0, 1])}\n",
|
||||
"batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n",
|
||||
" 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7],\n",
|
||||
" [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n",
|
||||
" 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7]]), 'seq_len': tensor([20, 20])}\n",
|
||||
"batch_y: {'target': tensor([0, 0])}\n",
|
||||
"batch_x: {'words': tensor([[ 4, 277, 685, 18, 7],\n",
|
||||
" [15618, 3204, 5, 1675, -1]]), 'seq_len': tensor([5, 4])}\n",
|
||||
"batch_y: {'target': tensor([1, 1])}\n",
|
||||
"batch_x: {'words': tensor([[ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n",
|
||||
" 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7],\n",
|
||||
" [ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n",
|
||||
" 1217, 7, -1, -1, -1, -1, -1, -1, -1, -1]]), 'seq_len': tensor([20, 12])}\n",
|
||||
"batch_y: {'target': tensor([0, 1])}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tmp_data.set_pad_val('words',-1)\n",
|
||||
"batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n",
|
||||
"for batch_x, batch_y in batch:\n",
|
||||
" print(\"batch_x: \",batch_x)\n",
|
||||
" print(\"batch_y: \", batch_y)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"batch_x: {'words': tensor([[ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n",
|
||||
" 1217, 7, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
||||
" [ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n",
|
||||
" 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([12, 20])}\n",
|
||||
"batch_y: {'target': tensor([1, 0])}\n",
|
||||
"batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n",
|
||||
" 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n",
|
||||
" 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n",
|
||||
" 1323, 4398, 7, 0, 0, 0, 0, 0, 0, 0],\n",
|
||||
" [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n",
|
||||
" 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n",
|
||||
" 7, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([33, 21])}\n",
|
||||
"batch_y: {'target': tensor([1, 0])}\n",
|
||||
"batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0],\n",
|
||||
" [ 14, 10, 437, 32, 78, 3, 78, 437, 7, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0]]), 'seq_len': tensor([9, 9])}\n",
|
||||
"batch_y: {'target': tensor([0, 1])}\n",
|
||||
"batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n",
|
||||
" 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
||||
" [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n",
|
||||
" 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([20, 20])}\n",
|
||||
"batch_y: {'target': tensor([0, 0])}\n",
|
||||
"batch_x: {'words': tensor([[ 4, 277, 685, 18, 7, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
||||
" [15618, 3204, 5, 1675, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([5, 4])}\n",
|
||||
"batch_y: {'target': tensor([1, 1])}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.core.field import Padder\n",
|
||||
"import numpy as np\n",
|
||||
"class FixLengthPadder(Padder):\n",
|
||||
" def __init__(self, pad_val=0, length=None):\n",
|
||||
" super().__init__(pad_val=pad_val)\n",
|
||||
" self.length = length\n",
|
||||
" assert self.length is not None, \"Creating FixLengthPadder with no specific length!\"\n",
|
||||
"\n",
|
||||
" def __call__(self, contents, field_name, field_ele_dtype, dim):\n",
|
||||
" #计算当前contents中的最大长度\n",
|
||||
" max_len = max(map(len, contents))\n",
|
||||
" #如果当前contents中的最大长度大于指定的padder length的话就报错\n",
|
||||
" assert max_len <= self.length, \"Fixed padder length smaller than actual length! with length {}\".format(max_len)\n",
|
||||
" array = np.full((len(contents), self.length), self.pad_val, dtype=field_ele_dtype)\n",
|
||||
" for i, content_i in enumerate(contents):\n",
|
||||
" array[i, :len(content_i)] = content_i\n",
|
||||
" return array\n",
|
||||
"\n",
|
||||
"#设定FixLengthPadder的固定长度为40\n",
|
||||
"tmp_padder = FixLengthPadder(pad_val=0,length=40)\n",
|
||||
"#利用dataset的set_padder函数设定words field的padder\n",
|
||||
"tmp_data.set_padder('words',tmp_padder)\n",
|
||||
"batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n",
|
||||
"for batch_x, batch_y in batch:\n",
|
||||
" print(\"batch_x: \",batch_x)\n",
|
||||
" print(\"batch_y: \", batch_y)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 使用DataSetIter自己编写训练过程\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"-----start training-----\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 2.68 seconds!\n",
|
||||
"Epoch 0 Avg Loss: 0.66 AccuracyMetric: acc=0.708716 29307ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.38 seconds!\n",
|
||||
"Epoch 1 Avg Loss: 0.41 AccuracyMetric: acc=0.770642 52200ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.51 seconds!\n",
|
||||
"Epoch 2 Avg Loss: 0.16 AccuracyMetric: acc=0.747706 70268ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.96 seconds!\n",
|
||||
"Epoch 3 Avg Loss: 0.06 AccuracyMetric: acc=0.741972 90349ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 1.04 seconds!\n",
|
||||
"Epoch 4 Avg Loss: 0.03 AccuracyMetric: acc=0.740826 114250ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.8 seconds!\n",
|
||||
"Epoch 5 Avg Loss: 0.02 AccuracyMetric: acc=0.738532 134742ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.65 seconds!\n",
|
||||
"Epoch 6 Avg Loss: 0.01 AccuracyMetric: acc=0.731651 154503ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.8 seconds!\n",
|
||||
"Epoch 7 Avg Loss: 0.01 AccuracyMetric: acc=0.738532 175397ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.36 seconds!\n",
|
||||
"Epoch 8 Avg Loss: 0.01 AccuracyMetric: acc=0.733945 192384ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.84 seconds!\n",
|
||||
"Epoch 9 Avg Loss: 0.01 AccuracyMetric: acc=0.744266 214417ms\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.04 seconds!\n",
|
||||
"[tester] \n",
|
||||
"AccuracyMetric: acc=0.786667\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'AccuracyMetric': {'acc': 0.786667}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import BucketSampler\n",
|
||||
"from fastNLP import DataSetIter\n",
|
||||
"from fastNLP.models import CNNText\n",
|
||||
"from fastNLP import Tester\n",
|
||||
"import torch\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"embed_dim = 100\n",
|
||||
"model = CNNText((len(vocab),embed_dim), num_classes=2, dropout=0.1)\n",
|
||||
"\n",
|
||||
"def train(epoch, data, devdata):\n",
|
||||
" optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
|
||||
" lossfunc = torch.nn.CrossEntropyLoss()\n",
|
||||
" batch_size = 32\n",
|
||||
"\n",
|
||||
" # 定义一个Batch,传入DataSet,规定batch_size和去batch的规则。\n",
|
||||
" # 顺序(Sequential),随机(Random),相似长度组成一个batch(Bucket)\n",
|
||||
" train_sampler = BucketSampler(batch_size=batch_size, seq_len_field_name='seq_len')\n",
|
||||
" train_batch = DataSetIter(batch_size=batch_size, dataset=data, sampler=train_sampler)\n",
|
||||
"\n",
|
||||
" start_time = time.time()\n",
|
||||
" print(\"-\"*5+\"start training\"+\"-\"*5)\n",
|
||||
" for i in range(epoch):\n",
|
||||
" loss_list = []\n",
|
||||
" for batch_x, batch_y in train_batch:\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" output = model(batch_x['words'])\n",
|
||||
" loss = lossfunc(output['pred'], batch_y['target'])\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" loss_list.append(loss.item())\n",
|
||||
"\n",
|
||||
" #这里verbose如果为0,在调用Tester对象的test()函数时不输出任何信息,返回评估信息; 如果为1,打印出验证结果,返回评估信息\n",
|
||||
" #在调用过Tester对象的test()函数后,调用其_format_eval_results(res)函数,结构化输出验证结果\n",
|
||||
" tester_tmp = Tester(devdata, model, metrics=AccuracyMetric(), verbose=0)\n",
|
||||
" res=tester_tmp.test()\n",
|
||||
"\n",
|
||||
" print('Epoch {:d} Avg Loss: {:.2f}'.format(i, sum(loss_list) / len(loss_list)),end=\" \")\n",
|
||||
" print(tester_tmp._format_eval_results(res),end=\" \")\n",
|
||||
" print('{:d}ms'.format(round((time.time()-start_time)*1000)))\n",
|
||||
" loss_list.clear()\n",
|
||||
"\n",
|
||||
"train(10, train_data, dev_data)\n",
|
||||
"#使用tester进行快速测试\n",
|
||||
"tester = Tester(test_data, model, metrics=AccuracyMetric())\n",
|
||||
"tester.test()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python Now",
|
||||
"language": "python",
|
||||
"name": "now"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,622 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 使用 Callback 自定义你的训练过程"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- 什么是 Callback\n",
|
||||
"- 使用 Callback \n",
|
||||
"- 一些常用的 Callback\n",
|
||||
"- 自定义实现 Callback"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"什么是Callback\n",
|
||||
"------\n",
|
||||
"\n",
|
||||
"Callback 是与 Trainer 紧密结合的模块,利用 Callback 可以在 Trainer 训练时,加入自定义的操作,比如梯度裁剪,学习率调节,测试模型的性能等。定义的 Callback 会在训练的特定阶段被调用。\n",
|
||||
"\n",
|
||||
"fastNLP 中提供了很多常用的 Callback ,开箱即用。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"使用 Callback\n",
|
||||
" ------\n",
|
||||
"\n",
|
||||
"使用 Callback 很简单,将需要的 callback 按 list 存储,以对应参数 ``callbacks`` 传入对应的 Trainer。Trainer 在训练时就会自动执行这些 Callback 指定的操作了。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-09-17T07:34:46.465871Z",
|
||||
"start_time": "2019-09-17T07:34:30.648758Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"In total 3 datasets:\n",
|
||||
"\ttest has 1200 instances.\n",
|
||||
"\ttrain has 9600 instances.\n",
|
||||
"\tdev has 1200 instances.\n",
|
||||
"In total 2 vocabs:\n",
|
||||
"\tchars has 4409 entries.\n",
|
||||
"\ttarget has 2 entries.\n",
|
||||
"\n",
|
||||
"training epochs started 2019-09-17-03-34-34\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.1 seconds!\n",
|
||||
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
|
||||
"AccuracyMetric: acc=0.863333\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.11 seconds!\n",
|
||||
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
|
||||
"AccuracyMetric: acc=0.886667\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.1 seconds!\n",
|
||||
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
|
||||
"AccuracyMetric: acc=0.890833\n",
|
||||
"\n",
|
||||
"\r\n",
|
||||
"In Epoch:3/Step:900, got best dev performance:\n",
|
||||
"AccuracyMetric: acc=0.890833\n",
|
||||
"Reloaded the best model.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import (Callback, EarlyStopCallback,\n",
|
||||
" Trainer, CrossEntropyLoss, AccuracyMetric)\n",
|
||||
"from fastNLP.models import CNNText\n",
|
||||
"import torch.cuda\n",
|
||||
"\n",
|
||||
"# prepare data\n",
|
||||
"def get_data():\n",
|
||||
" from fastNLP.io import ChnSentiCorpPipe as pipe\n",
|
||||
" data = pipe().process_from_file()\n",
|
||||
" print(data)\n",
|
||||
" data.rename_field('chars', 'words')\n",
|
||||
" train_data = data.datasets['train']\n",
|
||||
" dev_data = data.datasets['dev']\n",
|
||||
" test_data = data.datasets['test']\n",
|
||||
" vocab = data.vocabs['words']\n",
|
||||
" tgt_vocab = data.vocabs['target']\n",
|
||||
" return train_data, dev_data, test_data, vocab, tgt_vocab\n",
|
||||
"\n",
|
||||
"# prepare model\n",
|
||||
"train_data, dev_data, _, vocab, tgt_vocab = get_data()\n",
|
||||
"device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
|
||||
"model = CNNText((len(vocab),50), num_classes=len(tgt_vocab))\n",
|
||||
"\n",
|
||||
"# define callback\n",
|
||||
"callbacks=[EarlyStopCallback(5)]\n",
|
||||
"\n",
|
||||
"# pass callbacks to Trainer\n",
|
||||
"def train_with_callback(cb_list):\n",
|
||||
" trainer = Trainer(\n",
|
||||
" device=device,\n",
|
||||
" n_epochs=3,\n",
|
||||
" model=model, \n",
|
||||
" train_data=train_data, \n",
|
||||
" dev_data=dev_data, \n",
|
||||
" loss=CrossEntropyLoss(), \n",
|
||||
" metrics=AccuracyMetric(), \n",
|
||||
" callbacks=cb_list, \n",
|
||||
" check_code_level=-1\n",
|
||||
" )\n",
|
||||
" trainer.train()\n",
|
||||
"\n",
|
||||
"train_with_callback(callbacks)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"fastNLP 中的 Callback\n",
|
||||
"-------\n",
|
||||
"fastNLP 中提供了很多常用的 Callback,如梯度裁剪,训练时早停和测试验证集,fitlog 等等。具体 Callback 请参考 fastNLP.core.callbacks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-09-17T07:35:02.182727Z",
|
||||
"start_time": "2019-09-17T07:34:49.443863Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"training epochs started 2019-09-17-03-34-49\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.13 seconds!\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.12 seconds!\n",
|
||||
"Evaluation on data-test:\n",
|
||||
"AccuracyMetric: acc=0.890833\n",
|
||||
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
|
||||
"AccuracyMetric: acc=0.890833\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.09 seconds!\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.09 seconds!\n",
|
||||
"Evaluation on data-test:\n",
|
||||
"AccuracyMetric: acc=0.8875\n",
|
||||
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
|
||||
"AccuracyMetric: acc=0.8875\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.11 seconds!\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.1 seconds!\n",
|
||||
"Evaluation on data-test:\n",
|
||||
"AccuracyMetric: acc=0.885\n",
|
||||
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
|
||||
"AccuracyMetric: acc=0.885\n",
|
||||
"\n",
|
||||
"\r\n",
|
||||
"In Epoch:1/Step:300, got best dev performance:\n",
|
||||
"AccuracyMetric: acc=0.890833\n",
|
||||
"Reloaded the best model.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import EarlyStopCallback, GradientClipCallback, EvaluateCallback\n",
|
||||
"callbacks = [\n",
|
||||
" EarlyStopCallback(5),\n",
|
||||
" GradientClipCallback(clip_value=5, clip_type='value'),\n",
|
||||
" EvaluateCallback(dev_data)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"train_with_callback(callbacks)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"自定义 Callback\n",
|
||||
"------\n",
|
||||
"\n",
|
||||
"这里我们以一个简单的 Callback作为例子,它的作用是打印每一个 Epoch 平均训练 loss。\n",
|
||||
"\n",
|
||||
"#### 创建 Callback\n",
|
||||
" \n",
|
||||
"要自定义 Callback,我们要实现一个类,继承 fastNLP.Callback。\n",
|
||||
"\n",
|
||||
"这里我们定义 MyCallBack ,继承 fastNLP.Callback 。\n",
|
||||
"\n",
|
||||
"#### 指定 Callback 调用的阶段\n",
|
||||
" \n",
|
||||
"Callback 中所有以 on_ 开头的类方法会在 Trainer 的训练中在特定阶段调用。 如 on_train_begin() 会在训练开始时被调用,on_epoch_end() 会在每个 epoch 结束时调用。 具体有哪些类方法,参见 Callback 文档。\n",
|
||||
"\n",
|
||||
"这里, MyCallBack 在求得loss时调用 on_backward_begin() 记录当前 loss ,在每一个 epoch 结束时调用 on_epoch_end() ,求当前 epoch 平均loss并输出。\n",
|
||||
"\n",
|
||||
"#### 使用 Callback 的属性访问 Trainer 的内部信息\n",
|
||||
" \n",
|
||||
"为了方便使用,可以使用 Callback 的属性,访问 Trainer 中的对应信息,如 optimizer, epoch, n_epochs,分别对应训练时的优化器,当前 epoch 数,和总 epoch 数。 具体可访问的属性,参见文档 Callback 。\n",
|
||||
"\n",
|
||||
"这里, MyCallBack 为了求平均 loss ,需要知道当前 epoch 的总步数,可以通过 self.step 属性得到当前训练了多少步。\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-09-17T07:43:10.907139Z",
|
||||
"start_time": "2019-09-17T07:42:58.488177Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"training epochs started 2019-09-17-03-42-58\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.11 seconds!\n",
|
||||
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
|
||||
"AccuracyMetric: acc=0.883333\n",
|
||||
"\n",
|
||||
"Avg loss at epoch 1, 0.100254\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.1 seconds!\n",
|
||||
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
|
||||
"AccuracyMetric: acc=0.8775\n",
|
||||
"\n",
|
||||
"Avg loss at epoch 2, 0.183511\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 0.13 seconds!\n",
|
||||
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
|
||||
"AccuracyMetric: acc=0.875833\n",
|
||||
"\n",
|
||||
"Avg loss at epoch 3, 0.257103\n",
|
||||
"\r\n",
|
||||
"In Epoch:1/Step:300, got best dev performance:\n",
|
||||
"AccuracyMetric: acc=0.883333\n",
|
||||
"Reloaded the best model.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import Callback\n",
|
||||
"from fastNLP import logger\n",
|
||||
"\n",
|
||||
"class MyCallBack(Callback):\n",
|
||||
" \"\"\"Print average loss in each epoch\"\"\"\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" self.total_loss = 0\n",
|
||||
" self.start_step = 0\n",
|
||||
" \n",
|
||||
" def on_backward_begin(self, loss):\n",
|
||||
" self.total_loss += loss.item()\n",
|
||||
" \n",
|
||||
" def on_epoch_end(self):\n",
|
||||
" n_steps = self.step - self.start_step\n",
|
||||
" avg_loss = self.total_loss / n_steps\n",
|
||||
" logger.info('Avg loss at epoch %d, %.6f', self.epoch, avg_loss)\n",
|
||||
" self.start_step = self.step\n",
|
||||
"\n",
|
||||
"callbacks = [MyCallBack()]\n",
|
||||
"train_with_callback(callbacks)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
},
|
||||
"varInspector": {
|
||||
"cols": {
|
||||
"lenName": 16,
|
||||
"lenType": 16,
|
||||
"lenVar": 40
|
||||
},
|
||||
"kernels_config": {
|
||||
"python": {
|
||||
"delete_cmd_postfix": "",
|
||||
"delete_cmd_prefix": "del ",
|
||||
"library": "var_list.py",
|
||||
"varRefreshCmd": "print(var_dic_list())"
|
||||
},
|
||||
"r": {
|
||||
"delete_cmd_postfix": ") ",
|
||||
"delete_cmd_prefix": "rm(",
|
||||
"library": "var_list.r",
|
||||
"varRefreshCmd": "cat(var_dic_list()) "
|
||||
}
|
||||
},
|
||||
"types_to_exclude": [
|
||||
"module",
|
||||
"function",
|
||||
"builtin_function_or_method",
|
||||
"instance",
|
||||
"_Feature"
|
||||
],
|
||||
"window_display": false
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -1,912 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 序列标注\n",
|
||||
"\n",
|
||||
"这一部分的内容主要展示如何使用fastNLP实现序列标注(Sequence labeling)任务。您可以使用fastNLP的各个组件快捷,方便地完成序列标注任务,达到出色的效果。 在阅读这篇教程前,希望您已经熟悉了fastNLP的基础使用,尤其是数据的载入以及模型的构建,通过这个小任务的能让您进一步熟悉fastNLP的使用。\n",
|
||||
"\n",
|
||||
"## 命名实体识别(name entity recognition, NER)\n",
|
||||
"\n",
|
||||
"命名实体识别任务是从文本中抽取出具有特殊意义或者指代性非常强的实体,通常包括人名、地名、机构名和时间等。 如下面的例子中\n",
|
||||
"\n",
|
||||
"*我来自复旦大学*\n",
|
||||
"\n",
|
||||
"其中“复旦大学”就是一个机构名,命名实体识别就是要从中识别出“复旦大学”这四个字是一个整体,且属于机构名这个类别。这个问题在实际做的时候会被 转换为序列标注问题\n",
|
||||
"\n",
|
||||
"针对\"我来自复旦大学\"这句话,我们的预测目标将是[O, O, O, B-ORG, I-ORG, I-ORG, I-ORG],其中O表示out,即不是一个实体,B-ORG是ORG( organization的缩写)这个类别的开头(Begin),I-ORG是ORG类别的中间(Inside)。\n",
|
||||
"\n",
|
||||
"在本tutorial中我们将通过fastNLP尝试写出一个能够执行以上任务的模型。\n",
|
||||
"\n",
|
||||
"## 载入数据\n",
|
||||
"\n",
|
||||
"fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的,您可以通过《使用Loader和Pipe处理数据》了解如何使用fastNLP提供的数据加载函数。下面我们以微博命名实体任务来演示一下在fastNLP进行序列标注任务。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n",
|
||||
"| raw_chars | target | chars | seq_len |\n",
|
||||
"+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n",
|
||||
"| ['科', '技', '全', '方', '位',... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... | [792, 1015, 156, 198, 291, 714... | 26 |\n",
|
||||
"| ['对', ',', '输', '给', '一',... | [0, 0, 0, 0, 0, 0, 3, 1, 0, 0,... | [123, 2, 1205, 115, 8, 24, 101... | 15 |\n",
|
||||
"+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.io import WeiboNERPipe\n",
|
||||
"data_bundle = WeiboNERPipe().process_from_file()\n",
|
||||
"print(data_bundle.get_dataset('train')[:2])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 模型构建\n",
|
||||
"\n",
|
||||
"首先选择需要使用的Embedding类型。关于Embedding的相关说明可以参见《使用Embedding模块将文本转成向量》。 在这里我们使用通过word2vec预训练的中文汉字embedding。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found 3321 out of 3471 words in the pre-training embedding.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import StaticEmbedding\n",
|
||||
"\n",
|
||||
"embed = StaticEmbedding(vocab=data_bundle.get_vocab('chars'), model_dir_or_name='cn-char-fastnlp-100d')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"选择好Embedding之后,我们可以使用fastNLP中自带的 fastNLP.models.BiLSTMCRF 作为模型。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP.models import BiLSTMCRF\n",
|
||||
"\n",
|
||||
"data_bundle.rename_field('chars', 'words') # 这是由于BiLSTMCRF模型的forward函数接受的words,而不是chars,所以需要把这一列重新命名\n",
|
||||
"model = BiLSTMCRF(embed=embed, num_classes=len(data_bundle.get_vocab('target')), num_layers=1, hidden_size=200, dropout=0.5,\n",
|
||||
" target_vocab=data_bundle.get_vocab('target'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 进行训练\n",
|
||||
"下面我们选择用来评估模型的metric,以及优化用到的优化函数。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import SpanFPreRecMetric\n",
|
||||
"from torch.optim import Adam\n",
|
||||
"from fastNLP import LossInForward\n",
|
||||
"\n",
|
||||
"metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))\n",
|
||||
"optimizer = Adam(model.parameters(), lr=1e-2)\n",
|
||||
"loss = LossInForward()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"使用Trainer进行训练, 您可以通过修改 device 的值来选择显卡。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"input fields after batch(if batch size is 2):\n",
|
||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
|
||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
|
||||
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
|
||||
"target fields after batch(if batch size is 2):\n",
|
||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
|
||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
|
||||
"\n",
|
||||
"training epochs started 2020-02-27-13-53-24\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=430.0), HTML(value='')), layout=Layout(di…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.89 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 1/10. Step:43/430: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.067797, pre=0.192771, rec=0.041131\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.9 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 2/10. Step:86/430: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.344086, pre=0.568047, rec=0.246787\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.88 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 3/10. Step:129/430: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.446701, pre=0.653465, rec=0.339332\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.81 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 4/10. Step:172/430: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.479871, pre=0.642241, rec=0.383033\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.91 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 5/10. Step:215/430: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.486312, pre=0.650862, rec=0.388175\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.87 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 6/10. Step:258/430: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.541401, pre=0.711297, rec=0.437018\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.86 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 7/10. Step:301/430: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.430335, pre=0.685393, rec=0.313625\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.82 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 8/10. Step:344/430: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.477759, pre=0.665138, rec=0.372751\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.81 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 9/10. Step:387/430: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.500759, pre=0.611111, rec=0.424165\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 0.8 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 10/10. Step:430/430: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.496025, pre=0.65, rec=0.401028\n",
|
||||
"\n",
|
||||
"\r\n",
|
||||
"In Epoch:6/Step:258, got best dev performance:\n",
|
||||
"SpanFPreRecMetric: f=0.541401, pre=0.711297, rec=0.437018\n",
|
||||
"Reloaded the best model.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'best_eval': {'SpanFPreRecMetric': {'f': 0.541401,\n",
|
||||
" 'pre': 0.711297,\n",
|
||||
" 'rec': 0.437018}},\n",
|
||||
" 'best_epoch': 6,\n",
|
||||
" 'best_step': 258,\n",
|
||||
" 'seconds': 121.39}"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import Trainer\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"device= 0 if torch.cuda.is_available() else 'cpu'\n",
|
||||
"trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,\n",
|
||||
" dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device)\n",
|
||||
"trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 进行测试\n",
|
||||
"训练结束之后过,可以通过 Tester 测试其在测试集上的性能"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=17.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 1.54 seconds!\n",
|
||||
"[tester] \n",
|
||||
"SpanFPreRecMetric: f=0.439024, pre=0.685279, rec=0.322967\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'SpanFPreRecMetric': {'f': 0.439024, 'pre': 0.685279, 'rec': 0.322967}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from fastNLP import Tester\n",
|
||||
"tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)\n",
|
||||
"tester.test()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 使用更强的Bert做序列标注\n",
|
||||
"\n",
|
||||
"在fastNLP使用Bert进行任务,您只需要把fastNLP.embeddings.StaticEmbedding 切换为 fastNLP.embeddings.BertEmbedding(可修改 device 选择显卡)。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n",
|
||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n",
|
||||
"Start to generate word pieces for word.\n",
|
||||
"Found(Or segment into word pieces) 3384 words out of 3471.\n",
|
||||
"input fields after batch(if batch size is 2):\n",
|
||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
|
||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
|
||||
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
|
||||
"target fields after batch(if batch size is 2):\n",
|
||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n",
|
||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
|
||||
"\n",
|
||||
"training epochs started 2020-02-27-13-58-51\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1130.0), HTML(value='')), layout=Layout(d…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 2.7 seconds!\n",
|
||||
"Evaluation on dev at Epoch 1/10. Step:113/1130: \n",
|
||||
"SpanFPreRecMetric: f=0.008114, pre=0.019231, rec=0.005141\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 2.49 seconds!\n",
|
||||
"Evaluation on dev at Epoch 2/10. Step:226/1130: \n",
|
||||
"SpanFPreRecMetric: f=0.467866, pre=0.467866, rec=0.467866\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 2.6 seconds!\n",
|
||||
"Evaluation on dev at Epoch 3/10. Step:339/1130: \n",
|
||||
"SpanFPreRecMetric: f=0.566879, pre=0.482821, rec=0.686375\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 2.56 seconds!\n",
|
||||
"Evaluation on dev at Epoch 4/10. Step:452/1130: \n",
|
||||
"SpanFPreRecMetric: f=0.651972, pre=0.59408, rec=0.722365\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 2.69 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 5/10. Step:565/1130: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.640909, pre=0.574338, rec=0.724936\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 2.52 seconds!\n",
|
||||
"Evaluation on dev at Epoch 6/10. Step:678/1130: \n",
|
||||
"SpanFPreRecMetric: f=0.661836, pre=0.624146, rec=0.70437\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 2.67 seconds!\n",
|
||||
"Evaluation on dev at Epoch 7/10. Step:791/1130: \n",
|
||||
"SpanFPreRecMetric: f=0.683429, pre=0.615226, rec=0.768638\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 2.37 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 8/10. Step:904/1130: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.674699, pre=0.634921, rec=0.719794\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluate data in 2.42 seconds!\n",
|
||||
"Evaluation on dev at Epoch 9/10. Step:1017/1130: \n",
|
||||
"SpanFPreRecMetric: f=0.693878, pre=0.650901, rec=0.742931\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 2.46 seconds!\n",
|
||||
"\r",
|
||||
"Evaluation on dev at Epoch 10/10. Step:1130/1130: \n",
|
||||
"\r",
|
||||
"SpanFPreRecMetric: f=0.686845, pre=0.62766, rec=0.758355\n",
|
||||
"\n",
|
||||
"\r\n",
|
||||
"In Epoch:9/Step:1017, got best dev performance:\n",
|
||||
"SpanFPreRecMetric: f=0.693878, pre=0.650901, rec=0.742931\n",
|
||||
"Reloaded the best model.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=17.0), HTML(value='')), layout=Layout(dis…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"Evaluate data in 1.96 seconds!\n",
|
||||
"[tester] \n",
|
||||
"SpanFPreRecMetric: f=0.626561, pre=0.596112, rec=0.660287\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'SpanFPreRecMetric': {'f': 0.626561, 'pre': 0.596112, 'rec': 0.660287}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"from fastNLP.io import WeiboNERPipe\n",
|
||||
"data_bundle = WeiboNERPipe().process_from_file()\n",
|
||||
"data_bundle.rename_field('chars', 'words')\n",
|
||||
"\n",
|
||||
"from fastNLP.embeddings import BertEmbedding\n",
|
||||
"embed = BertEmbedding(vocab=data_bundle.get_vocab('words'), model_dir_or_name='cn')\n",
|
||||
"model = BiLSTMCRF(embed=embed, num_classes=len(data_bundle.get_vocab('target')), num_layers=1, hidden_size=200, dropout=0.5,\n",
|
||||
" target_vocab=data_bundle.get_vocab('target'))\n",
|
||||
"\n",
|
||||
"from fastNLP import SpanFPreRecMetric\n",
|
||||
"from torch.optim import Adam\n",
|
||||
"from fastNLP import LossInForward\n",
|
||||
"metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))\n",
|
||||
"optimizer = Adam(model.parameters(), lr=2e-5)\n",
|
||||
"loss = LossInForward()\n",
|
||||
"\n",
|
||||
"from fastNLP import Trainer\n",
|
||||
"import torch\n",
|
||||
"device= 5 if torch.cuda.is_available() else 'cpu'\n",
|
||||
"trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer, batch_size=12,\n",
|
||||
" dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device)\n",
|
||||
"trainer.train()\n",
|
||||
"\n",
|
||||
"from fastNLP import Tester\n",
|
||||
"tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)\n",
|
||||
"tester.test()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python Now",
|
||||
"language": "python",
|
||||
"name": "now"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1,564 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 文本分类(Text classification)\n",
|
||||
"文本分类任务是将一句话或一段话划分到某个具体的类别。比如垃圾邮件识别,文本情绪分类等。\n",
|
||||
"\n",
|
||||
"Example:: \n",
|
||||
"1,商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"其中开头的1是只这条评论的标签,表示是正面的情绪。我们将使用到的数据可以通过http://dbcloud.irocn.cn:8989/api/public/dl/dataset/chn_senti_corp.zip 下载并解压,当然也可以通过fastNLP自动下载该数据。\n",
|
||||
"\n",
|
||||
"数据中的内容如下图所示。接下来,我们将用fastNLP在这个数据上训练一个分类网络。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"![jupyter](./cn_cls_example.png)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 步骤\n",
|
||||
"一共有以下的几个步骤 \n",
|
||||
"(1) 读取数据 \n",
|
||||
"(2) 预处理数据 \n",
|
||||
"(3) 选择预训练词向量 \n",
|
||||
"(4) 创建模型 \n",
|
||||
"(5) 训练模型 "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### (1) 读取数据\n",
|
||||
"fastNLP提供多种数据的自动下载与自动加载功能,对于这里我们要用到的数据,我们可以用\\ref{Loader}自动下载并加载该数据。更多有关Loader的使用可以参考\\ref{Loader}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP.io import ChnSentiCorpLoader\n",
|
||||
"\n",
|
||||
"loader = ChnSentiCorpLoader() # 初始化一个中文情感分类的loader\n",
|
||||
"data_dir = loader.download() # 这一行代码将自动下载数据到默认的缓存地址, 并将该地址返回\n",
|
||||
"data_bundle = loader.load(data_dir) # 这一行代码将从{data_dir}处读取数据至DataBundle"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"DataBundle的相关介绍,可以参考\\ref{}。我们可以打印该data_bundle的基本信息。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(data_bundle)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"可以看出,该data_bundle中一个含有三个\\ref{DataSet}。通过下面的代码,我们可以查看DataSet的基本情况"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(data_bundle.get_dataset('train')[:2]) # 查看Train集前两个sample"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### (2) 预处理数据\n",
|
||||
"在NLP任务中,预处理一般包括: (a)将一整句话切分成汉字或者词; (b)将文本转换为index \n",
|
||||
"\n",
|
||||
"fastNLP中也提供了多种数据集的处理类,这里我们直接使用fastNLP的ChnSentiCorpPipe。更多关于Pipe的说明可以参考\\ref{Pipe}。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP.io import ChnSentiCorpPipe\n",
|
||||
"\n",
|
||||
"pipe = ChnSentiCorpPipe()\n",
|
||||
"data_bundle = pipe.process(data_bundle) # 所有的Pipe都实现了process()方法,且输入输出都为DataBundle类型"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(data_bundle) # 打印data_bundle,查看其变化"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"可以看到除了之前已经包含的3个\\ref{DataSet}, 还新增了两个\\ref{Vocabulary}。我们可以打印DataSet中的内容"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(data_bundle.get_dataset('train')[:2])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"新增了一列为数字列表的chars,以及变为数字的target列。可以看出这两列的名称和刚好与data_bundle中两个Vocabulary的名称是一致的,我们可以打印一下Vocabulary看一下里面的内容。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"char_vocab = data_bundle.get_vocab('chars')\n",
|
||||
"print(char_vocab)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Vocabulary是一个记录着词语与index之间映射关系的类,比如"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"index = char_vocab.to_index('选')\n",
|
||||
"print(\"'选'的index是{}\".format(index)) # 这个值与上面打印出来的第一个instance的chars的第一个index是一致的\n",
|
||||
"print(\"index:{}对应的汉字是{}\".format(index, char_vocab.to_word(index))) "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### (3) 选择预训练词向量 \n",
|
||||
"由于Word2vec, Glove, Elmo, Bert等预训练模型可以增强模型的性能,所以在训练具体任务前,选择合适的预训练词向量非常重要。在fastNLP中我们提供了多种Embedding使得加载这些预训练模型的过程变得更加便捷。更多关于Embedding的说明可以参考\\ref{Embedding}。这里我们先给出一个使用word2vec的中文汉字预训练的示例,之后再给出一个使用Bert的文本分类。这里使用的预训练词向量为'cn-fastnlp-100d',fastNLP将自动下载该embedding至本地缓存,fastNLP支持使用名字指定的Embedding以及相关说明可以参见\\ref{Embedding}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import StaticEmbedding\n",
|
||||
"\n",
|
||||
"word2vec_embed = StaticEmbedding(char_vocab, model_dir_or_name='cn-char-fastnlp-100d')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### (4) 创建模型\n",
|
||||
"这里我们使用到的模型结构如下所示,补图"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"from fastNLP.modules import LSTM\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"# 定义模型\n",
|
||||
"class BiLSTMMaxPoolCls(nn.Module):\n",
|
||||
" def __init__(self, embed, num_classes, hidden_size=400, num_layers=1, dropout=0.3):\n",
|
||||
" super().__init__()\n",
|
||||
" self.embed = embed\n",
|
||||
" \n",
|
||||
" self.lstm = LSTM(self.embed.embedding_dim, hidden_size=hidden_size//2, num_layers=num_layers, \n",
|
||||
" batch_first=True, bidirectional=True)\n",
|
||||
" self.dropout_layer = nn.Dropout(dropout)\n",
|
||||
" self.fc = nn.Linear(hidden_size, num_classes)\n",
|
||||
" \n",
|
||||
" def forward(self, chars, seq_len): # 这里的名称必须和DataSet中相应的field对应,比如之前我们DataSet中有chars,这里就必须为chars\n",
|
||||
" # chars:[batch_size, max_len]\n",
|
||||
" # seq_len: [batch_size, ]\n",
|
||||
" chars = self.embed(chars)\n",
|
||||
" outputs, _ = self.lstm(chars, seq_len)\n",
|
||||
" outputs = self.dropout_layer(outputs)\n",
|
||||
" outputs, _ = torch.max(outputs, dim=1)\n",
|
||||
" outputs = self.fc(outputs)\n",
|
||||
" \n",
|
||||
" return {'pred':outputs} # [batch_size,], 返回值必须是dict类型,且预测值的key建议设为pred\n",
|
||||
"\n",
|
||||
"# 初始化模型\n",
|
||||
"model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### (5) 训练模型\n",
|
||||
"fastNLP提供了Trainer对象来组织训练过程,包括完成loss计算(所以在初始化Trainer的时候需要指定loss类型),梯度更新(所以在初始化Trainer的时候需要提供优化器optimizer)以及在验证集上的性能验证(所以在初始化时需要提供一个Metric)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import Trainer\n",
|
||||
"from fastNLP import CrossEntropyLoss\n",
|
||||
"from torch.optim import Adam\n",
|
||||
"from fastNLP import AccuracyMetric\n",
|
||||
"\n",
|
||||
"loss = CrossEntropyLoss()\n",
|
||||
"optimizer = Adam(model.parameters(), lr=0.001)\n",
|
||||
"metric = AccuracyMetric()\n",
|
||||
"device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快\n",
|
||||
"\n",
|
||||
"trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
|
||||
" optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n",
|
||||
" metrics=metric, device=device)\n",
|
||||
"trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型\n",
|
||||
"\n",
|
||||
"# 在测试集上测试一下模型的性能\n",
|
||||
"from fastNLP import Tester\n",
|
||||
"print(\"Performance on test is:\")\n",
|
||||
"tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
|
||||
"tester.test()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 使用Bert进行文本分类"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 只需要切换一下Embedding即可\n",
|
||||
"from fastNLP.embeddings import BertEmbedding\n",
|
||||
"\n",
|
||||
"# 这里为了演示一下效果,所以默认Bert不更新权重\n",
|
||||
"bert_embed = BertEmbedding(char_vocab, model_dir_or_name='cn', auto_truncate=True, requires_grad=False)\n",
|
||||
"model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')), )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from fastNLP import Trainer\n",
|
||||
"from fastNLP import CrossEntropyLoss\n",
|
||||
"from torch.optim import Adam\n",
|
||||
"from fastNLP import AccuracyMetric\n",
|
||||
"\n",
|
||||
"loss = CrossEntropyLoss()\n",
|
||||
"optimizer = Adam(model.parameters(), lr=2e-5)\n",
|
||||
"metric = AccuracyMetric()\n",
|
||||
"device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快\n",
|
||||
"\n",
|
||||
"trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
|
||||
" optimizer=optimizer, batch_size=16, dev_data=data_bundle.get_dataset('test'),\n",
|
||||
" metrics=metric, device=device, n_epochs=3)\n",
|
||||
"trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型\n",
|
||||
"\n",
|
||||
"# 在测试集上测试一下模型的性能\n",
|
||||
"from fastNLP import Tester\n",
|
||||
"print(\"Performance on test is:\")\n",
|
||||
"tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
|
||||
"tester.test()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 基于词进行文本分类"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"由于汉字中没有显示的字与字的边界,一般需要通过分词器先将句子进行分词操作。\n",
|
||||
"下面的例子演示了如何不基于fastNLP已有的数据读取、预处理代码进行文本分类。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### (1) 读取数据"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"这里我们继续以之前的数据为例,但这次我们不使用fastNLP自带的数据读取代码 "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP.io import ChnSentiCorpLoader\n",
|
||||
"\n",
|
||||
"loader = ChnSentiCorpLoader() # 初始化一个中文情感分类的loader\n",
|
||||
"data_dir = loader.download() # 这一行代码将自动下载数据到默认的缓存地址, 并将该地址返回"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"下面我们先定义一个read_file_to_dataset的函数, 即给定一个文件路径,读取其中的内容,并返回一个DataSet。然后我们将所有的DataSet放入到DataBundle对象中来方便接下来的预处理"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from fastNLP import DataSet, Instance\n",
|
||||
"from fastNLP.io import DataBundle\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def read_file_to_dataset(fp):\n",
|
||||
" ds = DataSet()\n",
|
||||
" with open(fp, 'r') as f:\n",
|
||||
" f.readline() # 第一行是title名称,忽略掉\n",
|
||||
" for line in f:\n",
|
||||
" line = line.strip()\n",
|
||||
" target, chars = line.split('\\t')\n",
|
||||
" ins = Instance(target=target, raw_chars=chars)\n",
|
||||
" ds.append(ins)\n",
|
||||
" return ds\n",
|
||||
"\n",
|
||||
"data_bundle = DataBundle()\n",
|
||||
"for name in ['train.tsv', 'dev.tsv', 'test.tsv']:\n",
|
||||
" fp = os.path.join(data_dir, name)\n",
|
||||
" ds = read_file_to_dataset(fp)\n",
|
||||
" data_bundle.set_dataset(name=name.split('.')[0], dataset=ds)\n",
|
||||
"\n",
|
||||
"print(data_bundle) # 查看以下数据集的情况\n",
|
||||
"# In total 3 datasets:\n",
|
||||
"# train has 9600 instances.\n",
|
||||
"# dev has 1200 instances.\n",
|
||||
"# test has 1200 instances."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### (2) 数据预处理"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"在这里,我们首先把句子通过 [fastHan](http://gitee.com/fastnlp/fastHan) 进行分词操作,然后创建词表,并将词语转换为序号。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastHan import FastHan\n",
|
||||
"from fastNLP import Vocabulary\n",
|
||||
"\n",
|
||||
"model=FastHan()\n",
|
||||
"# model.set_device('cuda')\n",
|
||||
"\n",
|
||||
"# 定义分词处理操作\n",
|
||||
"def word_seg(ins):\n",
|
||||
" raw_chars = ins['raw_chars']\n",
|
||||
" # 由于有些句子比较长,我们只截取前128个汉字\n",
|
||||
" raw_words = model(raw_chars[:128], target='CWS')[0]\n",
|
||||
" return raw_words\n",
|
||||
"\n",
|
||||
"for name, ds in data_bundle.iter_datasets():\n",
|
||||
" # apply函数将对内部的instance依次执行word_seg操作,并把其返回值放入到raw_words这个field\n",
|
||||
" ds.apply(word_seg, new_field_name='raw_words')\n",
|
||||
" # 除了apply函数,fastNLP还支持apply_field, apply_more(可同时创建多个field)等操作\n",
|
||||
" # 同时我们增加一个seq_len的field\n",
|
||||
" ds.add_seq_len('raw_words')\n",
|
||||
"\n",
|
||||
"vocab = Vocabulary()\n",
|
||||
"\n",
|
||||
"# 对raw_words列创建词表, 建议把非训练集的dataset放在no_create_entry_dataset参数中\n",
|
||||
"# 也可以通过add_word(), add_word_lst()等建立词表,请参考http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_2_vocabulary.html\n",
|
||||
"vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_words', \n",
|
||||
" no_create_entry_dataset=[data_bundle.get_dataset('dev'), \n",
|
||||
" data_bundle.get_dataset('test')]) \n",
|
||||
"\n",
|
||||
"# 将建立好词表的Vocabulary用于对raw_words列建立词表,并把转为序号的列存入到words列\n",
|
||||
"vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'), \n",
|
||||
" data_bundle.get_dataset('test'), field_name='raw_words', new_field_name='words')\n",
|
||||
"\n",
|
||||
"# 建立target的词表,target的词表一般不需要padding和unknown\n",
|
||||
"target_vocab = Vocabulary(padding=None, unknown=None) \n",
|
||||
"# 一般情况下我们可以只用训练集建立target的词表\n",
|
||||
"target_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='target') \n",
|
||||
"# 如果没有传递new_field_name, 则默认覆盖原词表\n",
|
||||
"target_vocab.index_dataset(data_bundle.get_dataset('train'), data_bundle.get_dataset('dev'), \n",
|
||||
" data_bundle.get_dataset('test'), field_name='target')\n",
|
||||
"\n",
|
||||
"# 我们可以把词表保存到data_bundle中,方便之后使用\n",
|
||||
"data_bundle.set_vocab(field_name='words', vocab=vocab)\n",
|
||||
"data_bundle.set_vocab(field_name='target', vocab=target_vocab)\n",
|
||||
"\n",
|
||||
"# 我们把words和target分别设置为input和target,这样它们才会在训练循环中被取出并自动padding, 有关这部分更多的内容参考\n",
|
||||
"# http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html\n",
|
||||
"data_bundle.set_target('target')\n",
|
||||
"data_bundle.set_input('words', 'seq_len') # DataSet也有这两个接口\n",
|
||||
"# 如果某些field,您希望它被设置为target或者input,但是不希望fastNLP自动padding或需要使用特定的padding方式,请参考\n",
|
||||
"# http://www.fastnlp.top/docs/fastNLP/fastNLP.core.dataset.html\n",
|
||||
"\n",
|
||||
"print(data_bundle.get_dataset('train')[:2]) # 我们可以看一下当前dataset的内容\n",
|
||||
"\n",
|
||||
"# 由于之后需要使用之前定义的BiLSTMMaxPoolCls模型,所以需要将words这个field修改为chars(因为该模型的forward接受chars参数)\n",
|
||||
"data_bundle.rename_field('words', 'chars')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### (3) 选择预训练词向量"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"这里我们选择腾讯的预训练中文词向量,可以在 [腾讯词向量](https://ai.tencent.com/ailab/nlp/en/embedding.html) 处下载并解压。这里我们不能直接使用BERT,因为BERT是基于中文字进行预训练的。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP.embeddings import StaticEmbedding\n",
|
||||
"\n",
|
||||
"word2vec_embed = StaticEmbedding(data_bundle.get_vocab('words'), \n",
|
||||
" model_dir_or_name='/path/to/Tencent_AILab_ChineseEmbedding.txt')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from fastNLP import Trainer\n",
|
||||
"from fastNLP import CrossEntropyLoss\n",
|
||||
"from torch.optim import Adam\n",
|
||||
"from fastNLP import AccuracyMetric\n",
|
||||
"\n",
|
||||
"# 初始化模型\n",
|
||||
"model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))\n",
|
||||
"\n",
|
||||
"# 开始训练\n",
|
||||
"loss = CrossEntropyLoss()\n",
|
||||
"optimizer = Adam(model.parameters(), lr=0.001)\n",
|
||||
"metric = AccuracyMetric()\n",
|
||||
"device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快\n",
|
||||
"\n",
|
||||
"trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
|
||||
" optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n",
|
||||
" metrics=metric, device=device)\n",
|
||||
"trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型\n",
|
||||
"\n",
|
||||
"# 在测试集上测试一下模型的性能\n",
|
||||
"from fastNLP import Tester\n",
|
||||
"print(\"Performance on test is:\")\n",
|
||||
"tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
|
||||
"tester.test()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
27
docs/source/_templates/versions.html
Normal file
27
docs/source/_templates/versions.html
Normal file
@ -0,0 +1,27 @@
|
||||
{%- if current_version %}
|
||||
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
|
||||
<span class="rst-current-version" data-toggle="rst-current-version">
|
||||
<span class="fa fa-book"> Other Versions</span>
|
||||
{{ current_version.name }}
|
||||
<span class="fa fa-caret-down"></span>
|
||||
</span>
|
||||
<div class="rst-other-versions">
|
||||
{%- if versions.tags %}
|
||||
<dl>
|
||||
<dt>Tags</dt>
|
||||
{%- for item in versions.tags %}
|
||||
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
|
||||
{%- endfor %}
|
||||
</dl>
|
||||
{%- endif %}
|
||||
{%- if versions.branches %}
|
||||
<dl>
|
||||
<dt>Branches</dt>
|
||||
{%- for item in versions.branches %}
|
||||
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
|
||||
{%- endfor %}
|
||||
</dl>
|
||||
{%- endif %}
|
||||
</div>
|
||||
</div>
|
||||
{%- endif %}
|
@ -20,13 +20,13 @@ sys.path.insert(0, os.path.abspath('../../'))
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = 'fastNLP'
|
||||
copyright = '2020, xpqiu'
|
||||
author = 'xpqiu'
|
||||
copyright = '2022, fastNLP'
|
||||
author = 'fastNLP'
|
||||
|
||||
# The short X.Y version
|
||||
version = '0.6.0'
|
||||
version = '1.0'
|
||||
# The full version, including alpha/beta/rc tags
|
||||
release = '0.6.0'
|
||||
release = '1.0.0-alpha'
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
@ -42,7 +42,10 @@ extensions = [
|
||||
'sphinx.ext.viewcode',
|
||||
'sphinx.ext.autosummary',
|
||||
'sphinx.ext.mathjax',
|
||||
'sphinx.ext.todo'
|
||||
'sphinx.ext.todo',
|
||||
'sphinx_autodoc_typehints',
|
||||
'sphinx_multiversion',
|
||||
'nbsphinx',
|
||||
]
|
||||
|
||||
autodoc_default_options = {
|
||||
@ -51,7 +54,12 @@ autodoc_default_options = {
|
||||
'undoc-members': False,
|
||||
}
|
||||
|
||||
add_module_names = False
|
||||
autosummary_ignore_module_all = False
|
||||
# autodoc_typehints = "description"
|
||||
autoclass_content = "class"
|
||||
typehints_fully_qualified = False
|
||||
typehints_defaults = "comma"
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
@ -110,12 +118,16 @@ html_static_path = ['_static']
|
||||
# 'searchbox.html']``.
|
||||
#
|
||||
# html_sidebars = {}
|
||||
|
||||
html_sidebars = {
|
||||
'**': [
|
||||
'versions.html',
|
||||
],
|
||||
}
|
||||
|
||||
# -- Options for HTMLHelp output ---------------------------------------------
|
||||
|
||||
# Output file base name for HTML help builder.
|
||||
htmlhelp_basename = 'fastNLP doc'
|
||||
htmlhelp_basename = 'fastNLP'
|
||||
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
|
||||
@ -140,17 +152,14 @@ latex_elements = {
|
||||
# Grouping the document tree into LaTeX files. List of tuples
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(master_doc, 'fastNLP.tex', 'fastNLP Documentation',
|
||||
'xpqiu', 'manual'),
|
||||
]
|
||||
latex_documents = []
|
||||
|
||||
# -- Options for manual page output ------------------------------------------
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [
|
||||
(master_doc, 'fastnlp', 'fastNLP Documentation',
|
||||
(master_doc, 'fastNLP', 'fastNLP Documentation',
|
||||
[author], 1)
|
||||
]
|
||||
|
||||
@ -161,10 +170,12 @@ man_pages = [
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(master_doc, 'fastNLP', 'fastNLP Documentation',
|
||||
author, 'fastNLP', 'One line description of project.',
|
||||
author, 'fastNLP', 'A fast NLP tool for programming.',
|
||||
'Miscellaneous'),
|
||||
]
|
||||
|
||||
# -- Options for Multiversions ----------------------------------------------
|
||||
smv_latest_version = 'dev0.8.0'
|
||||
|
||||
# -- Extension configuration -------------------------------------------------
|
||||
def maybe_skip_member(app, what, name, obj, skip, options):
|
||||
@ -174,7 +185,7 @@ def maybe_skip_member(app, what, name, obj, skip, options):
|
||||
return False
|
||||
if name.startswith("_"):
|
||||
return True
|
||||
return False
|
||||
return skip
|
||||
|
||||
|
||||
def setup(app):
|
||||
|
@ -1,7 +0,0 @@
|
||||
fastNLP.core.batch
|
||||
==================
|
||||
|
||||
.. automodule:: fastNLP.core.batch
|
||||
:members: BatchIter, DataSetIter, TorchLoaderIter
|
||||
:inherited-members:
|
||||
|
@ -1,7 +0,0 @@
|
||||
fastNLP.core.callback
|
||||
=====================
|
||||
|
||||
.. automodule:: fastNLP.core.callback
|
||||
:members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, EarlyStopError
|
||||
:inherited-members:
|
||||
|
7
docs/source/fastNLP.core.callbacks.callback.rst
Normal file
7
docs/source/fastNLP.core.callbacks.callback.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.callback module
|
||||
======================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.callbacks.callback_event.rst
Normal file
7
docs/source/fastNLP.core.callbacks.callback_event.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.callback\_event module
|
||||
=============================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.callback_event
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.callbacks.callback_manager.rst
Normal file
7
docs/source/fastNLP.core.callbacks.callback_manager.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.callback\_manager module
|
||||
===============================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.callback_manager
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.checkpoint\_callback module
|
||||
==================================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.checkpoint_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.early\_stop\_callback module
|
||||
===================================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.early_stop_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.callbacks.fitlog_callback.rst
Normal file
7
docs/source/fastNLP.core.callbacks.fitlog_callback.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.fitlog\_callback module
|
||||
==============================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.fitlog_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.has\_monitor\_callback module
|
||||
====================================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.has_monitor_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.load\_best\_model\_callback module
|
||||
=========================================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.load_best_model_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.lr\_scheduler\_callback module
|
||||
=====================================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.lr_scheduler_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.more\_evaluate\_callback module
|
||||
======================================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.more_evaluate_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.callbacks.progress_callback.rst
Normal file
7
docs/source/fastNLP.core.callbacks.progress_callback.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.progress\_callback module
|
||||
================================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.progress_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
36
docs/source/fastNLP.core.callbacks.rst
Normal file
36
docs/source/fastNLP.core.callbacks.rst
Normal file
@ -0,0 +1,36 @@
|
||||
fastNLP.core.callbacks package
|
||||
==============================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.callbacks.torch_callbacks
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.callbacks.callback
|
||||
fastNLP.core.callbacks.callback_event
|
||||
fastNLP.core.callbacks.callback_manager
|
||||
fastNLP.core.callbacks.checkpoint_callback
|
||||
fastNLP.core.callbacks.early_stop_callback
|
||||
fastNLP.core.callbacks.fitlog_callback
|
||||
fastNLP.core.callbacks.has_monitor_callback
|
||||
fastNLP.core.callbacks.load_best_model_callback
|
||||
fastNLP.core.callbacks.lr_scheduler_callback
|
||||
fastNLP.core.callbacks.more_evaluate_callback
|
||||
fastNLP.core.callbacks.progress_callback
|
||||
fastNLP.core.callbacks.timer_callback
|
||||
fastNLP.core.callbacks.topk_saver
|
||||
fastNLP.core.callbacks.utils
|
7
docs/source/fastNLP.core.callbacks.timer_callback.rst
Normal file
7
docs/source/fastNLP.core.callbacks.timer_callback.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.timer\_callback module
|
||||
=============================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.timer_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.callbacks.topk_saver.rst
Normal file
7
docs/source/fastNLP.core.callbacks.topk_saver.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.topk\_saver module
|
||||
=========================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.topk_saver
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
16
docs/source/fastNLP.core.callbacks.torch_callbacks.rst
Normal file
16
docs/source/fastNLP.core.callbacks.torch_callbacks.rst
Normal file
@ -0,0 +1,16 @@
|
||||
fastNLP.core.callbacks.torch\_callbacks package
|
||||
===============================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.torch_callbacks
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.callbacks.torch_callbacks.torch_grad_clip_callback
|
||||
fastNLP.core.callbacks.torch_callbacks.torch_lr_sched_callback
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.torch\_callbacks.torch\_grad\_clip\_callback module
|
||||
==========================================================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.torch_callbacks.torch_grad_clip_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.torch\_callbacks.torch\_lr\_sched\_callback module
|
||||
=========================================================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.torch_callbacks.torch_lr_sched_callback
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.callbacks.utils.rst
Normal file
7
docs/source/fastNLP.core.callbacks.utils.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.callbacks.utils module
|
||||
===================================
|
||||
|
||||
.. automodule:: fastNLP.core.callbacks.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.collators.collator.rst
Normal file
7
docs/source/fastNLP.core.collators.collator.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.collator module
|
||||
======================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.collator
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.collators.packer_unpacker.rst
Normal file
7
docs/source/fastNLP.core.collators.packer_unpacker.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.packer\_unpacker module
|
||||
==============================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.packer_unpacker
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.padders.exceptions module
|
||||
================================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders.exceptions
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.padders.get\_padder module
|
||||
=================================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders.get_padder
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.padders.jittor\_padder module
|
||||
====================================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders.jittor_padder
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.padders.numpy\_padder module
|
||||
===================================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders.numpy_padder
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.padders.oneflow\_padder module
|
||||
=====================================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders.oneflow_padder
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.collators.padders.padder.rst
Normal file
7
docs/source/fastNLP.core.collators.padders.padder.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.padders.padder module
|
||||
============================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders.padder
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.padders.paddle\_padder module
|
||||
====================================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders.paddle_padder
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.padders.raw\_padder module
|
||||
=================================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders.raw_padder
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
25
docs/source/fastNLP.core.collators.padders.rst
Normal file
25
docs/source/fastNLP.core.collators.padders.rst
Normal file
@ -0,0 +1,25 @@
|
||||
fastNLP.core.collators.padders package
|
||||
======================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.collators.padders.exceptions
|
||||
fastNLP.core.collators.padders.get_padder
|
||||
fastNLP.core.collators.padders.jittor_padder
|
||||
fastNLP.core.collators.padders.numpy_padder
|
||||
fastNLP.core.collators.padders.oneflow_padder
|
||||
fastNLP.core.collators.padders.padder
|
||||
fastNLP.core.collators.padders.paddle_padder
|
||||
fastNLP.core.collators.padders.raw_padder
|
||||
fastNLP.core.collators.padders.torch_padder
|
||||
fastNLP.core.collators.padders.torch_utils
|
||||
fastNLP.core.collators.padders.utils
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.padders.torch\_padder module
|
||||
===================================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders.torch_padder
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.padders.torch\_utils module
|
||||
==================================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders.torch_utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.collators.padders.utils.rst
Normal file
7
docs/source/fastNLP.core.collators.padders.utils.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.collators.padders.utils module
|
||||
===========================================
|
||||
|
||||
.. automodule:: fastNLP.core.collators.padders.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
24
docs/source/fastNLP.core.collators.rst
Normal file
24
docs/source/fastNLP.core.collators.rst
Normal file
@ -0,0 +1,24 @@
|
||||
fastNLP.core.collators package
|
||||
==============================
|
||||
|
||||
.. automodule:: fastNLP.core.collators
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.collators.padders
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.collators.collator
|
||||
fastNLP.core.collators.packer_unpacker
|
@ -1,7 +0,0 @@
|
||||
fastNLP.core.const
|
||||
==================
|
||||
|
||||
.. automodule:: fastNLP.core.const
|
||||
:members: Const
|
||||
:inherited-members:
|
||||
|
7
docs/source/fastNLP.core.controllers.evaluator.rst
Normal file
7
docs/source/fastNLP.core.controllers.evaluator.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.controllers.evaluator module
|
||||
=========================================
|
||||
|
||||
.. automodule:: fastNLP.core.controllers.evaluator
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.controllers.loops.evaluate\_batch\_loop module
|
||||
===========================================================
|
||||
|
||||
.. automodule:: fastNLP.core.controllers.loops.evaluate_batch_loop
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.controllers.loops.loop.rst
Normal file
7
docs/source/fastNLP.core.controllers.loops.loop.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.controllers.loops.loop module
|
||||
==========================================
|
||||
|
||||
.. automodule:: fastNLP.core.controllers.loops.loop
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
17
docs/source/fastNLP.core.controllers.loops.rst
Normal file
17
docs/source/fastNLP.core.controllers.loops.rst
Normal file
@ -0,0 +1,17 @@
|
||||
fastNLP.core.controllers.loops package
|
||||
======================================
|
||||
|
||||
.. automodule:: fastNLP.core.controllers.loops
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.controllers.loops.evaluate_batch_loop
|
||||
fastNLP.core.controllers.loops.loop
|
||||
fastNLP.core.controllers.loops.train_batch_loop
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.controllers.loops.train\_batch\_loop module
|
||||
========================================================
|
||||
|
||||
.. automodule:: fastNLP.core.controllers.loops.train_batch_loop
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
25
docs/source/fastNLP.core.controllers.rst
Normal file
25
docs/source/fastNLP.core.controllers.rst
Normal file
@ -0,0 +1,25 @@
|
||||
fastNLP.core.controllers package
|
||||
================================
|
||||
|
||||
.. automodule:: fastNLP.core.controllers
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.controllers.loops
|
||||
fastNLP.core.controllers.utils
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.controllers.evaluator
|
||||
fastNLP.core.controllers.trainer
|
7
docs/source/fastNLP.core.controllers.trainer.rst
Normal file
7
docs/source/fastNLP.core.controllers.trainer.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.controllers.trainer module
|
||||
=======================================
|
||||
|
||||
.. automodule:: fastNLP.core.controllers.trainer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
16
docs/source/fastNLP.core.controllers.utils.rst
Normal file
16
docs/source/fastNLP.core.controllers.utils.rst
Normal file
@ -0,0 +1,16 @@
|
||||
fastNLP.core.controllers.utils package
|
||||
======================================
|
||||
|
||||
.. automodule:: fastNLP.core.controllers.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.controllers.utils.state
|
||||
fastNLP.core.controllers.utils.utils
|
7
docs/source/fastNLP.core.controllers.utils.state.rst
Normal file
7
docs/source/fastNLP.core.controllers.utils.state.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.controllers.utils.state module
|
||||
===========================================
|
||||
|
||||
.. automodule:: fastNLP.core.controllers.utils.state
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.controllers.utils.utils.rst
Normal file
7
docs/source/fastNLP.core.controllers.utils.utils.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.controllers.utils.utils module
|
||||
===========================================
|
||||
|
||||
.. automodule:: fastNLP.core.controllers.utils.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.dataloaders.jittor\_dataloader.fdl module
|
||||
======================================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders.jittor_dataloader.fdl
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
15
docs/source/fastNLP.core.dataloaders.jittor_dataloader.rst
Normal file
15
docs/source/fastNLP.core.dataloaders.jittor_dataloader.rst
Normal file
@ -0,0 +1,15 @@
|
||||
fastNLP.core.dataloaders.jittor\_dataloader package
|
||||
===================================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders.jittor_dataloader
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.dataloaders.jittor_dataloader.fdl
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.dataloaders.oneflow\_dataloader.fdl module
|
||||
=======================================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders.oneflow_dataloader.fdl
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
15
docs/source/fastNLP.core.dataloaders.oneflow_dataloader.rst
Normal file
15
docs/source/fastNLP.core.dataloaders.oneflow_dataloader.rst
Normal file
@ -0,0 +1,15 @@
|
||||
fastNLP.core.dataloaders.oneflow\_dataloader package
|
||||
====================================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders.oneflow_dataloader
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.dataloaders.oneflow_dataloader.fdl
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.dataloaders.paddle\_dataloader.fdl module
|
||||
======================================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders.paddle_dataloader.fdl
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
15
docs/source/fastNLP.core.dataloaders.paddle_dataloader.rst
Normal file
15
docs/source/fastNLP.core.dataloaders.paddle_dataloader.rst
Normal file
@ -0,0 +1,15 @@
|
||||
fastNLP.core.dataloaders.paddle\_dataloader package
|
||||
===================================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders.paddle_dataloader
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.dataloaders.paddle_dataloader.fdl
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.dataloaders.prepare\_dataloader module
|
||||
===================================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders.prepare_dataloader
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
27
docs/source/fastNLP.core.dataloaders.rst
Normal file
27
docs/source/fastNLP.core.dataloaders.rst
Normal file
@ -0,0 +1,27 @@
|
||||
fastNLP.core.dataloaders package
|
||||
================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.dataloaders.jittor_dataloader
|
||||
fastNLP.core.dataloaders.oneflow_dataloader
|
||||
fastNLP.core.dataloaders.paddle_dataloader
|
||||
fastNLP.core.dataloaders.torch_dataloader
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.dataloaders.prepare_dataloader
|
||||
fastNLP.core.dataloaders.utils
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.dataloaders.torch\_dataloader.fdl module
|
||||
=====================================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders.torch_dataloader.fdl
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.dataloaders.torch\_dataloader.mix\_dataloader module
|
||||
=================================================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders.torch_dataloader.mix_dataloader
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
16
docs/source/fastNLP.core.dataloaders.torch_dataloader.rst
Normal file
16
docs/source/fastNLP.core.dataloaders.torch_dataloader.rst
Normal file
@ -0,0 +1,16 @@
|
||||
fastNLP.core.dataloaders.torch\_dataloader package
|
||||
==================================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders.torch_dataloader
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.dataloaders.torch_dataloader.fdl
|
||||
fastNLP.core.dataloaders.torch_dataloader.mix_dataloader
|
7
docs/source/fastNLP.core.dataloaders.utils.rst
Normal file
7
docs/source/fastNLP.core.dataloaders.utils.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.dataloaders.utils module
|
||||
=====================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataloaders.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.dataset.dataset.rst
Normal file
7
docs/source/fastNLP.core.dataset.dataset.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.dataset.dataset module
|
||||
===================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataset.dataset
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.dataset.field.rst
Normal file
7
docs/source/fastNLP.core.dataset.field.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.dataset.field module
|
||||
=================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataset.field
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.dataset.instance.rst
Normal file
7
docs/source/fastNLP.core.dataset.instance.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.dataset.instance module
|
||||
====================================
|
||||
|
||||
.. automodule:: fastNLP.core.dataset.instance
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -1,7 +1,17 @@
|
||||
fastNLP.core.dataset
|
||||
====================
|
||||
fastNLP.core.dataset package
|
||||
============================
|
||||
|
||||
.. automodule:: fastNLP.core.dataset
|
||||
:members: DataSet
|
||||
:inherited-members:
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.dataset.dataset
|
||||
fastNLP.core.dataset.field
|
||||
fastNLP.core.dataset.instance
|
||||
|
7
docs/source/fastNLP.core.drivers.choose_driver.rst
Normal file
7
docs/source/fastNLP.core.drivers.choose_driver.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.drivers.choose\_driver module
|
||||
==========================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.choose_driver
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.drivers.driver.rst
Normal file
7
docs/source/fastNLP.core.drivers.driver.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.drivers.driver module
|
||||
==================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.driver
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.drivers.jittor\_driver.initialize\_jittor\_driver module
|
||||
=====================================================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.jittor_driver.initialize_jittor_driver
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.drivers.jittor\_driver.jittor\_driver module
|
||||
=========================================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.jittor_driver.jittor_driver
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.drivers.jittor_driver.mpi.rst
Normal file
7
docs/source/fastNLP.core.drivers.jittor_driver.mpi.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.drivers.jittor\_driver.mpi module
|
||||
==============================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.jittor_driver.mpi
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
19
docs/source/fastNLP.core.drivers.jittor_driver.rst
Normal file
19
docs/source/fastNLP.core.drivers.jittor_driver.rst
Normal file
@ -0,0 +1,19 @@
|
||||
fastNLP.core.drivers.jittor\_driver package
|
||||
===========================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.jittor_driver
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.drivers.jittor_driver.initialize_jittor_driver
|
||||
fastNLP.core.drivers.jittor_driver.jittor_driver
|
||||
fastNLP.core.drivers.jittor_driver.mpi
|
||||
fastNLP.core.drivers.jittor_driver.single_device
|
||||
fastNLP.core.drivers.jittor_driver.utils
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.drivers.jittor\_driver.single\_device module
|
||||
=========================================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.jittor_driver.single_device
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.drivers.jittor_driver.utils.rst
Normal file
7
docs/source/fastNLP.core.drivers.jittor_driver.utils.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.drivers.jittor\_driver.utils module
|
||||
================================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.jittor_driver.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
7
docs/source/fastNLP.core.drivers.oneflow_driver.ddp.rst
Normal file
7
docs/source/fastNLP.core.drivers.oneflow_driver.ddp.rst
Normal file
@ -0,0 +1,7 @@
|
||||
fastNLP.core.drivers.oneflow\_driver.ddp module
|
||||
===============================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.oneflow_driver.ddp
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.drivers.oneflow\_driver.dist\_utils module
|
||||
=======================================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.oneflow_driver.dist_utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.drivers.oneflow\_driver.initialize\_oneflow\_driver module
|
||||
=======================================================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
@ -0,0 +1,7 @@
|
||||
fastNLP.core.drivers.oneflow\_driver.oneflow\_driver module
|
||||
===========================================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.oneflow_driver.oneflow_driver
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
20
docs/source/fastNLP.core.drivers.oneflow_driver.rst
Normal file
20
docs/source/fastNLP.core.drivers.oneflow_driver.rst
Normal file
@ -0,0 +1,20 @@
|
||||
fastNLP.core.drivers.oneflow\_driver package
|
||||
============================================
|
||||
|
||||
.. automodule:: fastNLP.core.drivers.oneflow_driver
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
fastNLP.core.drivers.oneflow_driver.ddp
|
||||
fastNLP.core.drivers.oneflow_driver.dist_utils
|
||||
fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver
|
||||
fastNLP.core.drivers.oneflow_driver.oneflow_driver
|
||||
fastNLP.core.drivers.oneflow_driver.single_device
|
||||
fastNLP.core.drivers.oneflow_driver.utils
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user