Ready for V0.3.1

* upgrade parser API and model
* update docs: add new pages for tutorials
* upgrade CWS api download source
* add a new method for dataset field access
* add introduction for bert
* add more unit tests for api/processor
* remove unused test data. Add new test data.
FengZiYjun 2019-02-04 09:44:54 +08:00
parent 986541139a
commit 0c5630bd16
16 changed files with 288 additions and 3412 deletions

View File

@ -1,7 +1,8 @@
fastNLP上手教程
fastNLP 10分钟上手教程
===============
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_10min_tutorial.ipynb
fastNLP提供方便的数据预处理，训练和测试模型的功能
DataSet & Instance

View File

@ -2,6 +2,8 @@
FastNLP 1分钟上手教程
=====================
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_1min_tutorial.ipynb
step 1
------

View File

@ -0,0 +1,5 @@
fastNLP 进阶教程
===============
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb

View File

@ -0,0 +1,5 @@
fastNLP 开发者指南
===============
原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/tutorial_for_developer.md

View File

@ -5,6 +5,7 @@ Installation
.. contents::
:local:
Make sure your environment meets the requirements listed in https://github.com/fastnlp/fastNLP/blob/master/requirements.txt .
Run the following commands to install fastNLP package:

View File

@ -6,4 +6,6 @@ Quickstart
../tutorials/fastnlp_1_minute_tutorial
../tutorials/fastnlp_10tmin_tutorial
../tutorials/fastnlp_advanced_tutorial
../tutorials/fastnlp_developer_guide

View File

@ -9,7 +9,7 @@ from fastNLP.core.dataset import DataSet
from fastNLP.api.utils import load_url
from fastNLP.api.processor import ModelProcessor
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader, add_seg_tag
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader
from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SpanFPreRecMetric
@ -17,9 +17,9 @@ from fastNLP.api.processor import IndexerProcessor
# TODO add pretrain urls
model_urls = {
"cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl",
"cws": "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656.pkl",
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
"parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl"
"parser": "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0.pkl"
}
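For context, a rough sketch of how one of these URLs could be resolved into a usable pipeline, assuming load_url behaves like torch's model_zoo loader (download to a cache directory, then torch.load) and that the checkpoint is a dict with a "pipeline" entry, as POS.test below reads from self._dict. This is an illustration, not the exact fastNLP loading code.

import os
import torch
from fastNLP.api.utils import load_url

def load_checkpoint(path_or_url, map_location="cpu"):
    # Local paths are loaded directly; anything else is fetched via load_url.
    # The map_location kwarg is assumed to mirror torch.utils.model_zoo.load_url.
    if os.path.exists(os.path.expanduser(path_or_url)):
        return torch.load(path_or_url, map_location=map_location)
    return load_url(path_or_url, map_location=map_location)

save_dict = load_checkpoint(model_urls["cws"])
pipeline = save_dict["pipeline"]  # applied to a DataSet the same way as self.pipeline(dataset)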
@ -90,38 +90,28 @@ class POS(API):
# 3. 使用pipeline
self.pipeline(dataset)
# def decode_tags(ins):
# pred_tags = ins["tag"]
# chars = ins["words"]
# words = []
# start_idx = 0
# for idx, tag in enumerate(pred_tags):
# if tag[0] == "S":
# words.append(chars[start_idx:idx + 1] + "/" + tag[2:])
# start_idx = idx + 1
# elif tag[0] == "E":
# words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
# start_idx = idx + 1
# return words
#
# dataset.apply(decode_tags, new_field_name="tag_output")
def merge_tag(words_list, tags_list):
rtn = []
for words, tags in zip(words_list, tags_list):
rtn.append([w + "/" + t for w, t in zip(words, tags)])
return rtn
output = dataset.field_arrays["tag"].content
if isinstance(content, str):
return output[0]
elif isinstance(content, list):
return output
return merge_tag(content, output)
def test(self, file_path):
test_data = ConllxDataLoader().load(file_path)
with open("model_pp_0117.pkl", "rb") as f:
save_dict = torch.load(f)
save_dict = self._dict
tag_vocab = save_dict["tag_vocab"]
pipeline = save_dict["pipeline"]
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
pipeline.pipeline = [index_tag] + pipeline.pipeline
test_data.rename_field("pos_tags", "tag")
pipeline(test_data)
test_data.set_target("truth")
prediction = test_data.field_arrays["predict"].content
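For a quick sense of the word/tag strings that the merge_tag helper above assembles, an illustrative input and its result (values chosen only for this example):

words_list = [["上海", "积极", "准备"]]
tags_list = [["NR", "AD", "VV"]]
# merge_tag(words_list, tags_list) -> [["上海/NR", "积极/AD", "准备/VV"]]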
@ -235,7 +225,7 @@ class CWS(API):
rec = eval_res['BMESF1PreRecMetric']['rec']
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
return f1, pre, rec
return {"F1": f1, "precision": pre, "recall": rec}
class Parser(API):
@ -260,6 +250,7 @@ class Parser(API):
dataset.add_field('wp', pos_out)
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words')
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos')
dataset.rename_field("words", "raw_words")
# 3. 使用pipeline
self.pipeline(dataset)
@ -269,31 +260,74 @@ class Parser(API):
# output like: [['2/top', '0/root', '4/nn', '2/dep']]
return dataset.field_arrays['output'].content
def test(self, filepath):
data = ConllxDataLoader().load(filepath)
ds = DataSet()
for ins1, ins2 in zip(add_seg_tag(data), data):
ds.append(Instance(words=ins1[0], tag=ins1[1],
gold_words=ins2[0], gold_pos=ins2[1],
gold_heads=ins2[2], gold_head_tags=ins2[3]))
def load_test_file(self, path):
def get_one(sample):
sample = list(map(list, zip(*sample)))
if len(sample) == 0:
return None
for w in sample[7]:
if w == '_':
print('Error Sample {}'.format(sample))
return None
# return word_seq, pos_seq, head_seq, head_tag_seq
return sample[1], sample[3], list(map(int, sample[6])), sample[7]
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)
data = [get_one(sample) for sample in datalist]
data_list = list(filter(lambda x: x is not None, data))
return data_list
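# For reference, applying load_test_file to the first sentence of the CoNLL-X
# test data added at the end of this commit (presumably
# test/data_for_tests/zh_sample.conllx) should yield, reading the columns
# form, cpos, head, deprel:
#   word_seq  = ['上海', '积极', '准备', '迎接', '欧元', '诞生']
#   pos_seq   = ['NR', 'AD', 'VV', 'VV', 'NN', 'NN']
#   heads     = [3, 3, 0, 3, 6, 4]
#   head_tags = ['nsubj', 'advmod', 'root', 'ccomp', 'nn', 'dobj']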
def test(self, filepath):
data = self.load_test_file(filepath)
def convert(data):
BOS = '<BOS>'
dataset = DataSet()
for sample in data:
word_seq = [BOS] + sample[0]
pos_seq = [BOS] + sample[1]
heads = [0] + sample[2]
head_tags = [BOS] + sample[3]
dataset.append(Instance(raw_words=word_seq,
pos=pos_seq,
gold_heads=heads,
arc_true=heads,
tags=head_tags))
return dataset
ds = convert(data)
pp = self.pipeline
for p in pp:
if p.field_name == 'word_list':
p.field_name = 'gold_words'
elif p.field_name == 'pos_list':
p.field_name = 'gold_pos'
# ds.rename_field("words", "raw_words")
# ds.rename_field("tag", "pos")
pp(ds)
head_cor, label_cor, total = 0, 0, 0
for ins in ds:
head_gold = ins['gold_heads']
head_pred = ins['heads']
head_pred = ins['arc_pred']
length = len(head_gold)
total += length
for i in range(length):
head_cor += 1 if head_pred[i] == head_gold[i] else 0
uas = head_cor / total
print('uas:{:.2f}'.format(uas))
# print('uas:{:.2f}'.format(uas))
for p in pp:
if p.field_name == 'gold_words':
@ -301,7 +335,7 @@ class Parser(API):
elif p.field_name == 'gold_pos':
p.field_name = 'pos_list'
return uas
return {"USA": round(uas, 5)}
class Analyzer:

View File

@ -15,19 +15,40 @@ def chinese_word_segmentation():
print(cws.predict(text))
def chinese_word_segmentation_test():
cws = CWS(device='cpu')
print(cws.test("../../test/data_for_tests/zh_sample.conllx"))
def pos_tagging():
# 输入已分词序列
text = ['编者 按: 7月 12日 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。']
text = [text[0].split()]
print(text)
pos = POS(device='cpu')
print(pos.predict(text))
def pos_tagging_test():
pos = POS(device='cpu')
print(pos.test("../../test/data_for_tests/zh_sample.conllx"))
def syntactic_parsing():
text = ['编者 按: 7月 12日 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。']
text = [text[0].split()]
parser = Parser(device='cpu')
print(parser.predict(text))
def syntactic_parsing_test():
parser = Parser(device='cpu')
print(parser.test("../../test/data_for_tests/zh_sample.conllx"))
if __name__ == "__main__":
chinese_word_segmentation()
chinese_word_segmentation_test()
pos_tagging()
pos_tagging_test()
syntactic_parsing()
syntactic_parsing_test()

View File

@ -102,6 +102,7 @@ class PreAppendProcessor(Processor):
[data] + instance[field_name]
"""
def __init__(self, data, field_name, new_added_field_name=None):
super(PreAppendProcessor, self).__init__(field_name, new_added_field_name)
self.data = data
@ -116,6 +117,7 @@ class SliceProcessor(Processor):
从某个field中只取部分内容。等价于instance[field_name][start:end:step]
"""
def __init__(self, start, end, step, field_name, new_added_field_name=None):
super(SliceProcessor, self).__init__(field_name, new_added_field_name)
for o in (start, end, step):
@ -132,6 +134,7 @@ class Num2TagProcessor(Processor):
将一句话中的数字转换为某个tag
"""
def __init__(self, tag, field_name, new_added_field_name=None):
"""
@ -163,6 +166,7 @@ class IndexerProcessor(Processor):
给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list，比如
['', '', xxx]
"""
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
@ -215,6 +219,7 @@ class SeqLenProcessor(Processor):
根据某个field新增一个sequence length的field。取该field的第一维
"""
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
self.is_input = is_input
@ -229,6 +234,7 @@ class SeqLenProcessor(Processor):
from fastNLP.core.utils import _build_args
class ModelProcessor(Processor):
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
"""
@ -292,6 +298,7 @@ class Index2WordProcessor(Processor):
将DataSet中某个为index的field根据vocab转换为str
"""
def __init__(self, vocab, field_name, new_added_field_name):
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
self.vocab = vocab
@ -303,7 +310,6 @@ class Index2WordProcessor(Processor):
class SetTargetProcessor(Processor):
# TODO; remove it.
def __init__(self, *fields, flag=True):
super(SetTargetProcessor, self).__init__(None, None)
self.fields = fields
@ -313,6 +319,7 @@ class SetTargetProcessor(Processor):
dataset.set_target(*self.fields, flag=self.flag)
return dataset
class SetInputProcessor(Processor):
def __init__(self, *fields, flag=True):
super(SetInputProcessor, self).__init__(None, None)
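Taken together, these processors are meant to be chained. A minimal sketch of such a chain, assuming Pipeline can be built from a list of processors (its .pipeline attribute is treated as a plain list in POS.test above); the IndexerProcessor and SeqLenProcessor signatures are the ones shown in this diff:

from fastNLP import Vocabulary
from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import IndexerProcessor, SeqLenProcessor
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

vocab = Vocabulary()
vocab.add_word_lst(["编者", "按", "公布"])
ds = DataSet([Instance(words=["编者", "按", "公布"])])

pp = Pipeline([
    IndexerProcessor(vocab, field_name="words", new_added_field_name="word_ids"),
    SeqLenProcessor(field_name="word_ids", new_added_field_name="seq_lens"),
])
ds = pp(ds)  # runs each processor in order, adding "word_ids" and then "seq_lens"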

View File

@ -92,6 +92,10 @@ class DataSet(object):
data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder,
is_input=field.is_input, is_target=field.is_target)
return data_set
elif isinstance(idx, str):
if idx not in self:
raise KeyError("No such field called {} in DataSet.".format(idx))
return self.field_arrays[idx]
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
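With this branch added, a field can be fetched from a DataSet by name. A small usage sketch:

from fastNLP.core.dataset import DataSet

ds = DataSet({"a": [1, 2, 3], "b": [4, 5, 6]})
field = ds["a"]        # returns the FieldArray stored under "a"
print(field.content)   # [1, 2, 3]
# ds["missing"] would raise KeyError: No such field called missing in DataSet.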

View File

@ -1,3 +1,7 @@
"""
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.
"""
import copy
import json
import math
@ -220,7 +224,23 @@ class BertPooler(nn.Module):
class BertModel(nn.Module):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
"""Bidirectional Embedding Representations from Transformers.
If you want to use pre-trained weights, please download from the following sources provided by pytorch-pretrained-BERT.
sources::
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
Construct a BERT model with pre-trained weights::
model = BertModel.from_pretrained("path/to/weights/directory")
"""

View File

@ -1,5 +1,5 @@
[train]
n_epochs = 1
n_epochs = 20
batch_size = 32
use_cuda = true
use_tqdm=true

View File

@ -1,9 +1,12 @@
import random
import unittest
from fastNLP import Vocabulary
import numpy as np
from fastNLP import Vocabulary, Instance
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, PreAppendProcessor, SliceProcessor, Num2TagProcessor, \
IndexerProcessor, VocabProcessor, SeqLenProcessor
IndexerProcessor, VocabProcessor, SeqLenProcessor, ModelProcessor, Index2WordProcessor, SetTargetProcessor, \
SetInputProcessor, VocabIndexerProcessor
from fastNLP.core.dataset import DataSet
@ -53,3 +56,46 @@ class TestProcessor(unittest.TestCase):
ds = proc(ds)
for data in ds.field_arrays["len"].content:
self.assertEqual(data, 30)
def test_ModelProcessor(self):
from fastNLP.models.cnn_text_classification import CNNText
model = CNNText(100, 100, 5)
ins_list = []
for _ in range(64):
seq_len = np.random.randint(5, 30)
ins_list.append(Instance(word_seq=[np.random.randint(0, 100) for _ in range(seq_len)], seq_lens=seq_len))
data_set = DataSet(ins_list)
data_set.set_input("word_seq", "seq_lens")
proc = ModelProcessor(model)
data_set = proc(data_set)
self.assertTrue("pred" in data_set)
def test_Index2WordProcessor(self):
vocab = Vocabulary()
vocab.add_word_lst(["a", "b", "c", "d", "e"])
proc = Index2WordProcessor(vocab, "tag_id", "tag")
data_set = DataSet([Instance(tag_id=[np.random.randint(0, 7) for _ in range(32)])])
data_set = proc(data_set)
self.assertTrue("tag" in data_set)
def test_SetTargetProcessor(self):
proc = SetTargetProcessor("a", "b", "c")
data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
data_set = proc(data_set)
self.assertTrue(data_set["a"].is_target)
self.assertTrue(data_set["b"].is_target)
self.assertTrue(data_set["c"].is_target)
def test_SetInputProcessor(self):
proc = SetInputProcessor("a", "b", "c")
data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
data_set = proc(data_set)
self.assertTrue(data_set["a"].is_input)
self.assertTrue(data_set["b"].is_input)
self.assertTrue(data_set["c"].is_input)
def test_VocabIndexerProcessor(self):
proc = VocabIndexerProcessor("word_seq", "word_ids")
data_set = DataSet([Instance(word_seq=["a", "b", "c", "d", "e"])])
data_set = proc(data_set)
self.assertTrue("word_ids" in data_set)

File diff suppressed because it is too large

View File

@ -1,2 +0,0 @@
迈向充满希望的新世纪——一九九八年新年讲话
附图片1张

View File

@ -0,0 +1,100 @@
1 上海 _ NR NR _ 3 nsubj _ _
2 积极 _ AD AD _ 3 advmod _ _
3 准备 _ VV VV _ 0 root _ _
4 迎接 _ VV VV _ 3 ccomp _ _
5 欧元 _ NN NN _ 6 nn _ _
6 诞生 _ NN NN _ 4 dobj _ _
1 新华社 _ NR NR _ 7 dep _ _
2 上海 _ NR NR _ 7 dep _ _
3 十二月 _ NT NT _ 7 dep _ _
4 三十日 _ NT NT _ 7 dep _ _
5 电 _ NN NN _ 7 dep _ _
6 （ _ PU PU _ 7 punct _ _
7 记者 _ NN NN _ 0 root _ _
8 潘清 _ NR NR _ 7 dep _ _
9 ） _ PU PU _ 7 punct _ _
1 即将 _ AD AD _ 2 advmod _ _
2 诞生 _ VV VV _ 4 rcmod _ _
3 的 _ DEC DEC _ 2 cpm _ _
4 欧元 _ NN NN _ 6 nsubj _ _
5 ， _ PU PU _ 6 punct _ _
6 引起 _ VV VV _ 0 root _ _
7 了 _ AS AS _ 6 asp _ _
8 上海 _ NR NR _ 14 nn _ _
9 这 _ DT DT _ 14 det _ _
10 个 _ M M _ 9 clf _ _
11 中国 _ NR NR _ 13 nn _ _
12 金融 _ NN NN _ 13 nn _ _
13 中心 _ NN NN _ 14 nn _ _
14 城市 _ NN NN _ 16 assmod _ _
15 的 _ DEG DEG _ 14 assm _ _
16 关注 _ NN NN _ 6 dobj _ _
17 。 _ PU PU _ 6 punct _ _
1 上海 _ NR NR _ 2 nn _ _
2 银行界 _ NN NN _ 4 nsubj _ _
3 纷纷 _ AD AD _ 4 advmod _ _
4 推出 _ VV VV _ 0 root _ _
5 了 _ AS AS _ 4 asp _ _
6 与 _ P P _ 8 prep _ _
7 之 _ PN PN _ 6 pobj _ _
8 相关 _ VA VA _ 15 rcmod _ _
9 的 _ DEC DEC _ 8 cpm _ _
10 外汇 _ NN NN _ 15 nn _ _
11 业务 _ NN NN _ 15 nn _ _
12 品种 _ NN NN _ 15 conj _ _
13 和 _ CC CC _ 15 cc _ _
14 服务 _ NN NN _ 15 nn _ _
15 举措 _ NN NN _ 4 dobj _ _
16 ， _ PU PU _ 4 punct _ _
17 积极 _ AD AD _ 18 advmod _ _
18 准备 _ VV VV _ 4 dep _ _
19 启动 _ VV VV _ 18 ccomp _ _
20 欧元 _ NN NN _ 21 nn _ _
21 业务 _ NN NN _ 19 dobj _ _
22 。 _ PU PU _ 4 punct _ _
1 一些 _ CD CD _ 8 nummod _ _
2 热衷于 _ VV VV _ 8 rcmod _ _
3 个人 _ NN NN _ 5 nn _ _
4 外汇 _ NN NN _ 5 nn _ _
5 交易 _ NN NN _ 2 dobj _ _
6 的 _ DEC DEC _ 2 cpm _ _
7 上海 _ NR NR _ 8 nn _ _
8 市民 _ NN NN _ 13 nsubj _ _
9 ， _ PU PU _ 13 punct _ _
10 也 _ AD AD _ 13 advmod _ _
11 对 _ P P _ 13 prep _ _
12 欧元 _ NN NN _ 11 pobj _ _
13 表示 _ VV VV _ 0 root _ _
14 出 _ VV VV _ 13 rcomp _ _
15 极 _ AD AD _ 16 advmod _ _
16 大 _ VA VA _ 18 rcmod _ _
17 的 _ DEC DEC _ 16 cpm _ _
18 兴趣 _ NN NN _ 13 dobj _ _
19 。 _ PU PU _ 13 punct _ _
1 继 _ P P _ 38 prep _ _
2 上海 _ NR NR _ 6 nn _ _
3 大众 _ NR NR _ 6 nn _ _
4 汽车 _ NN NN _ 6 nn _ _
5 有限 _ JJ JJ _ 6 amod _ _
6 公司 _ NN NN _ 13 nsubj _ _
7 十八日 _ NT NT _ 13 tmod _ _
8 在 _ P P _ 13 prep _ _
9 中国 _ NR NR _ 10 nn _ _
10 银行 _ NN NN _ 12 nn _ _
11 上海 _ NR NR _ 12 nn _ _
12 分行 _ NN NN _ 8 pobj _ _
13 开立 _ VV VV _ 19 lccomp _ _
14 上海 _ NR NR _ 16 dep _ _
15 第一 _ OD OD _ 16 ordmod _ _
16 个 _ M M _ 18 clf _ _
17 欧元 _ NN NN _ 18 nn _ _
18 帐户 _ NN NN _ 13 dobj _ _
19 后 _ LC LC _ 1 plmod _ _
20 ， _ PU PU _ 38 punct _ _
21 工商 _ NN NN _ 28 nn _ _
22 银行 _ NN NN _ 28 conj _ _