# fastNLP/legacy/api/api.py

import warnings
import torch
warnings.filterwarnings('ignore')
import os
from fastNLP.core.dataset import DataSet
from .utils import load_url
from .processor import ModelProcessor
from fastNLP.io.dataset_loader import _cut_long_sentence
from fastNLP.io.data_loader import ConllLoader
from fastNLP.core.instance import Instance
from ..api.pipeline import Pipeline
from fastNLP.core.metrics import SpanFPreRecMetric
from .processor import IndexerProcessor
# TODO add pretrain urls
model_urls = {
"cws": "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656.pkl",
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
"parser": "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0.pkl"
}
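
# A minimal usage sketch (assumptions: this module is importable as
# ``fastNLP.legacy.api.api`` and the pretrained pickles behind ``model_urls``
# are still reachable and compatible with the installed fastNLP version):
#
#     from fastNLP.legacy.api.api import CWS, POS
#     cws = CWS(device='cpu')   # downloads model_urls['cws'] on first use
#     print(cws.predict("中文分词很重要!"))
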
class ConllCWSReader(object):
"""Deprecated. Use ConllLoader for all types of conll-format files."""
def __init__(self):
pass
def load(self, path, cut_long_sent=False):
"""
        The returned DataSet contains only the field ``raw_sentence``, whose content is a str.
        The input is assumed to be in conll format: two sentences are separated by a blank
        line, and each non-empty line has 7 columns, e.g.::
            1 编者按 编者按 NN O 11 nmod:topic
            2 : : PU O 11 punct
            3 7月 7月 NT DATE 4 compound:nn
            4 12日 12日 NT DATE 11 nmod:tmod
            5 , , PU O 11 punct

            1 这 这 DT O 3 det
            2 款 款 M O 1 mark:clf
            3 飞行 飞行 NN O 8 nsubj
            4 从 从 P O 5 case
            5 外型 外型 NN O 8 nmod:prep
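
        A usage sketch (the file path is hypothetical)::

            reader = ConllCWSReader()
            ds = reader.load('path/to/test.conllx', cut_long_sent=True)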
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.strip().split())
if len(sample) > 0:
datalist.append(sample)
ds = DataSet()
for sample in datalist:
# print(sample)
res = self.get_char_lst(sample)
if res is None:
continue
line = ' '.join(res)
if cut_long_sent:
sents = _cut_long_sentence(line)
else:
sents = [line]
for raw_sentence in sents:
ds.append(Instance(raw_sentence=raw_sentence))
return ds
def get_char_lst(self, sample):
if len(sample) == 0:
return None
text = []
for w in sample:
t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
if t3 == '_':
return None
text.append(t1)
return text
class ConllxDataLoader(ConllLoader):
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。
Deprecated. Use ConllLoader for all types of conll-format files.
"""
def __init__(self):
headers = [
'words', 'pos_tags', 'heads', 'labels',
]
        indexes = [
            1, 3, 6, 7,
        ]
        super(ConllxDataLoader, self).__init__(headers=headers, indexes=indexes)
class API:
def __init__(self):
self.pipeline = None
self._dict = None
def predict(self, *args, **kwargs):
"""Do prediction for the given input.
"""
raise NotImplementedError
def test(self, file_path):
"""Test performance over the given data set.
        :param str file_path: path to the test data set.
:return: a dictionary of metric values
"""
raise NotImplementedError
def load(self, path, device):
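        # The saved object is a dict holding (at least) a serialized ``pipeline``;
        # it is read from a local file if ``path`` exists, otherwise it is fetched
        # with ``load_url`` (e.g. from one of the ``model_urls`` entries above).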
if os.path.exists(os.path.expanduser(path)):
_dict = torch.load(path, map_location='cpu')
else:
_dict = load_url(path, map_location='cpu')
self._dict = _dict
self.pipeline = _dict['pipeline']
for processor in self.pipeline.pipeline:
if isinstance(processor, ModelProcessor):
processor.set_model_device(device)
class POS(API):
"""FastNLP API for Part-Of-Speech tagging.
:param str model_path: the path to the model.
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch.
"""
def __init__(self, model_path=None, device='cpu'):
super(POS, self).__init__()
if model_path is None:
model_path = model_urls['pos']
self.load(model_path, device)
def predict(self, content):
"""predict函数的介绍
函数介绍的第二句,这句话不会换行
:param content: list of list of str. Each string is a token(word).
:return answer: list of list of str. Each string is a tag.
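
        Example (a minimal sketch; assumes the default pretrained POS model has
        been downloaded and loaded successfully)::

            pos = POS(device='cpu')
            print(pos.predict([['编者按', ':', '7月', '12日']]))
            # e.g. [['编者按/NN', ':/PU', '7月/NT', '12日/NT']] -- illustrative output only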
"""
        if self.pipeline is None:
raise ValueError("You have to load model first.")
sentence_list = content
        # 1. Check the type of each sentence
        for sentence in sentence_list:
            if not all(isinstance(obj, str) for obj in sentence):
                raise ValueError("Input must be list of list of string.")
        # 2. Build the DataSet
dataset = DataSet()
dataset.add_field("words", sentence_list)
        # 3. Run the pipeline
self.pipeline(dataset)
def merge_tag(words_list, tags_list):
rtn = []
for words, tags in zip(words_list, tags_list):
rtn.append([w + "/" + t for w, t in zip(words, tags)])
return rtn
output = dataset.field_arrays["tag"].content
if isinstance(content, str):
return output[0]
elif isinstance(content, list):
return merge_tag(content, output)
def test(self, file_path):
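        """Evaluate the tagger on a CoNLL-X formatted test file (the expected layout is
        sketched in :meth:`CWS.test`) and return span-based scores.

        :param str file_path: path to the test file.
        :return: dict with keys "F1", "precision" and "recall".
        """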
test_data = ConllxDataLoader().load(file_path)
save_dict = self._dict
tag_vocab = save_dict["tag_vocab"]
pipeline = save_dict["pipeline"]
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
pipeline.pipeline = [index_tag] + pipeline.pipeline
test_data.rename_field("pos_tags", "tag")
pipeline(test_data)
test_data.set_target("truth")
prediction = test_data.field_arrays["predict"].content
truth = test_data.field_arrays["truth"].content
seq_len = test_data.field_arrays["word_seq_origin_len"].content
# padding by hand
max_length = max([len(seq) for seq in prediction])
for idx in range(len(prediction)):
prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx])))
truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx])))
evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth",
seq_len="word_seq_origin_len")
evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)},
{"truth": torch.Tensor(truth)})
test_result = evaluator.get_metric()
f1 = round(test_result['f'] * 100, 2)
pre = round(test_result['pre'] * 100, 2)
rec = round(test_result['rec'] * 100, 2)
return {"F1": f1, "precision": pre, "recall": rec}
class CWS(API):
"""
    High-level API for Chinese word segmentation (CWS).

    :param model_path: if None, the model at the default location is used; if no model
        exists there, the pretrained model is downloaded automatically.
    :param device: str, e.g. 'cpu', 'cuda' or 'cuda:0'. The model will be loaded onto
        this device for inference.
"""
def __init__(self, model_path=None, device='cpu'):
super(CWS, self).__init__()
if model_path is None:
model_path = model_urls['cws']
self.load(model_path, device)
def predict(self, content):
"""
        Word segmentation interface.

        :param content: str or List[str]. For example, "中文分词很重要!" returns
            "中文 分词 很 重要 !"; if a List[str] such as ["中文分词很重要!", ...] is passed
            in, the result is ["中文 分词 很 重要 !", ...].
        :return: str or List[str], depending on the type of the input.
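
        Example (a minimal sketch; assumes the default pretrained CWS model has
        been downloaded and loaded successfully)::

            cws = CWS(device='cpu')
            print(cws.predict("中文分词很重要!"))    # -> "中文 分词 很 重要 !"
            print(cws.predict(["中文分词很重要!"]))  # -> ["中文 分词 很 重要 !"]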
"""
        if self.pipeline is None:
raise ValueError("You have to load model first.")
        sentence_list = []
        # 1. Check the type of the input
        if isinstance(content, str):
            sentence_list.append(content)
        elif isinstance(content, list):
            sentence_list = content
        else:
            raise ValueError("Input must be str or List[str].")
        # 2. Build the DataSet
dataset = DataSet()
dataset.add_field('raw_sentence', sentence_list)
        # 3. Run the pipeline
self.pipeline(dataset)
output = dataset.get_field('output').content
if isinstance(content, str):
return output[0]
elif isinstance(content, list):
return output
def test(self, filepath):
"""
        Take the path of a segmentation test file and return the segmentation f1, precision
        and recall on that data set.
        The test file should look like::
            1 编者按 编者按 NN O 11 nmod:topic
            2 : : PU O 11 punct
            3 7月 7月 NT DATE 4 compound:nn
            4 12日 12日 NT DATE 11 nmod:tmod
            5 , , PU O 11 punct

            1 这 这 DT O 3 det
            2 款 款 M O 1 mark:clf
            3 飞行 飞行 NN O 8 nsubj
            4 从 从 P O 5 case
            5 外型 外型 NN O 8 nmod:prep
        Two sentences are separated by a blank line, and each non-empty line has 7 columns.

        :param filepath: str, path of the test file.
        :return: dict with keys "F1", "precision" and "recall".
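
        A usage sketch (the file path is hypothetical)::

            cws = CWS(device='cpu')
            print(cws.test('path/to/test.conllx'))
            # -> {"F1": ..., "precision": ..., "recall": ...}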
"""
tag_proc = self._dict['tag_proc']
cws_model = self.pipeline.pipeline[-2].model
pipeline = self.pipeline.pipeline[:-2]
pipeline.insert(1, tag_proc)
pp = Pipeline(pipeline)
reader = ConllCWSReader()
# te_filename = '/home/hyan/ctb3/test.conllx'
te_dataset = reader.load(filepath)
pp(te_dataset)
from ..core.tester import Tester
from ..core.metrics import SpanFPreRecMetric
tester = Tester(data=te_dataset, model=cws_model, metrics=SpanFPreRecMetric(tag_proc.get_vocab()), batch_size=64,
verbose=0)
eval_res = tester.test()
f1 = eval_res['SpanFPreRecMetric']['f']
pre = eval_res['SpanFPreRecMetric']['pre']
rec = eval_res['SpanFPreRecMetric']['rec']
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
return {"F1": f1, "precision": pre, "recall": rec}
class Parser(API):
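    """Dependency parsing API. It first runs the :class:`POS` tagger above to obtain
    word segmentation and POS tags, then predicts a head index and a dependency label
    for each word; each output item looks like ``'2/top'`` (see :meth:`predict`).

    :param str model_path: path or URL of the parser model; defaults to ``model_urls['parser']``.
    :param str device: device name such as "cpu" or "cuda:0".
    """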
def __init__(self, model_path=None, device='cpu'):
super(Parser, self).__init__()
if model_path is None:
model_path = model_urls['parser']
self.pos_tagger = POS(device=device)
self.load(model_path, device)
def predict(self, content):
        if self.pipeline is None:
raise ValueError("You have to load model first.")
        # 1. Use the POS tagger to obtain word segmentation and POS tagging results
pos_out = self.pos_tagger.predict(content)
# pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()]
        # 2. Build the DataSet
dataset = DataSet()
dataset.add_field('wp', pos_out)
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words')
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos')
dataset.rename_field("words", "raw_words")
        # 3. Run the pipeline
self.pipeline(dataset)
dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred')
dataset.apply(lambda x: [arc + '/' + label for arc, label in
zip(x['arc_pred'], x['label_pred_seq'])][1:], new_field_name='output')
# output like: [['2/top', '0/root', '4/nn', '2/dep']]
return dataset.field_arrays['output'].content
def load_test_file(self, path):
def get_one(sample):
sample = list(map(list, zip(*sample)))
if len(sample) == 0:
return None
for w in sample[7]:
if w == '_':
print('Error Sample {}'.format(sample))
return None
# return word_seq, pos_seq, head_seq, head_tag_seq
return sample[1], sample[3], list(map(int, sample[6])), sample[7]
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)
data = [get_one(sample) for sample in datalist]
data_list = list(filter(lambda x: x is not None, data))
return data_list
def test(self, filepath):
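        """Evaluate the parser on a CoNLL-X formatted test file and return its unlabeled
        attachment score (the fraction of words whose predicted head matches the gold head).

        :return: dict with key "UAS", rounded to 5 decimal places.
        """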
data = self.load_test_file(filepath)
def convert(data):
BOS = '<BOS>'
dataset = DataSet()
for sample in data:
word_seq = [BOS] + sample[0]
pos_seq = [BOS] + sample[1]
heads = [0] + sample[2]
head_tags = [BOS] + sample[3]
dataset.append(Instance(raw_words=word_seq,
pos=pos_seq,
gold_heads=heads,
arc_true=heads,
tags=head_tags))
return dataset
ds = convert(data)
pp = self.pipeline
for p in pp:
if p.field_name == 'word_list':
p.field_name = 'gold_words'
elif p.field_name == 'pos_list':
p.field_name = 'gold_pos'
# ds.rename_field("words", "raw_words")
# ds.rename_field("tag", "pos")
pp(ds)
head_cor, label_cor, total = 0, 0, 0
for ins in ds:
head_gold = ins['gold_heads']
head_pred = ins['arc_pred']
length = len(head_gold)
total += length
for i in range(length):
head_cor += 1 if head_pred[i] == head_gold[i] else 0
uas = head_cor / total
# print('uas:{:.2f}'.format(uas))
for p in pp:
if p.field_name == 'gold_words':
p.field_name = 'word_list'
elif p.field_name == 'gold_pos':
p.field_name = 'pos_list'
return {"USA": round(uas, 5)}
class Analyzer:
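    """Convenience wrapper that bundles the :class:`CWS`, :class:`POS` and :class:`Parser` APIs.

    A usage sketch (assumes all three pretrained models can be downloaded)::

        analyzer = Analyzer(device='cpu')
        print(analyzer.predict("中文分词很重要!", seg=True, pos=True))
    """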
def __init__(self, device='cpu'):
self.cws = CWS(device=device)
self.pos = POS(device=device)
self.parser = Parser(device=device)
def predict(self, content, seg=False, pos=False, parser=False):
if seg is False and pos is False and parser is False:
seg = True
output_dict = {}
if seg:
seg_output = self.cws.predict(content)
output_dict['seg'] = seg_output
if pos:
pos_output = self.pos.predict(content)
output_dict['pos'] = pos_output
if parser:
parser_output = self.parser.predict(content)
output_dict['parser'] = parser_output
return output_dict
def test(self, filepath):
output_dict = {}
if self.cws:
seg_output = self.cws.test(filepath)
output_dict['seg'] = seg_output
if self.pos:
pos_output = self.pos.test(filepath)
output_dict['pos'] = pos_output
if self.parser:
parser_output = self.parser.test(filepath)
output_dict['parser'] = parser_output
return output_dict