import os
import warnings

import torch

warnings.filterwarnings('ignore')

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.io.dataset_loader import _cut_long_sentence
from fastNLP.io.data_loader import ConllLoader

from .processor import IndexerProcessor, ModelProcessor
from .utils import load_url
from ..api.pipeline import Pipeline

# TODO add pretrain urls
model_urls = {
    "cws": "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656.pkl",
    "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
    "parser": "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0.pkl"
}


class ConllCWSReader(object):
    """Deprecated. Use ConllLoader for all types of conll-format files."""

    def __init__(self):
        pass

    def load(self, path, cut_long_sent=False):
        """
        The returned DataSet contains only the field ``raw_sentence``, whose content is a str.

        The input is assumed to be in conll format: sentences are separated by blank lines,
        and every non-blank line has 7 columns, i.e.
        ::

            1 编者按 编者按 NN O 11 nmod:topic
            2 : : PU O 11 punct
            3 7月 7月 NT DATE 4 compound:nn
            4 12日 12日 NT DATE 11 nmod:tmod
            5 , , PU O 11 punct

            1 这 这 DT O 3 det
            2 款 款 M O 1 mark:clf
            3 飞行 飞行 NN O 8 nsubj
            4 从 从 P O 5 case
            5 外型 外型 NN O 8 nmod:prep

        """
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):  # a blank line ends the current sentence
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):  # skip comment lines
                    continue
                else:
                    sample.append(line.strip().split())
            if len(sample) > 0:
                datalist.append(sample)

        ds = DataSet()
        for sample in datalist:
            res = self.get_char_lst(sample)
            if res is None:
                continue
            line = ' '.join(res)
            if cut_long_sent:
                sents = _cut_long_sentence(line)
            else:
                sents = [line]
            for raw_sentence in sents:
                ds.append(Instance(raw_sentence=raw_sentence))
        return ds

    def get_char_lst(self, sample):
        if len(sample) == 0:
            return None
        text = []
        for w in sample:
            # w[1] is the word form, w[6] the head index; '_' marks an
            # incomplete annotation, in which case the sample is dropped.
            word, head = w[1], w[6]
            if head == '_':
                return None
            text.append(word)
        return text


class ConllxDataLoader(ConllLoader):
    """Reads word-level annotations: words, POS tags, (syntactic) head indices, and
    (syntactic) dependency labels. Completely different from ``ZhConllPOSReader``.

    Deprecated. Use ConllLoader for all types of conll-format files.
    """

    def __init__(self):
        headers = [
            'words', 'pos_tags', 'heads', 'labels',
        ]
        indexes = [
            1, 3, 6, 7,
        ]
        super(ConllxDataLoader, self).__init__(headers=headers, indexes=indexes)


class API:
    def __init__(self):
        self.pipeline = None
        self._dict = None

    def predict(self, *args, **kwargs):
        """Do prediction for the given input.
        """
        raise NotImplementedError

    def test(self, file_path):
        """Test performance over the given data set.

        :param str file_path: path to the test data set
        :return: a dictionary of metric values
        """
        raise NotImplementedError

    def load(self, path, device):
        if os.path.exists(os.path.expanduser(path)):
            _dict = torch.load(path, map_location='cpu')
        else:
            _dict = load_url(path, map_location='cpu')
        self._dict = _dict
        self.pipeline = _dict['pipeline']
        for processor in self.pipeline.pipeline:
            if isinstance(processor, ModelProcessor):
                processor.set_model_device(device)
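

# A minimal sketch of the checkpoint layout that API.load expects. The key names
# are exactly the ones read in this file ('pipeline' above, 'tag_vocab' in
# POS.test, 'tag_proc' in CWS.test); the variable names and the file name in the
# example are hypothetical:
#
#     torch.save({'pipeline': pipeline,      # a fastNLP Pipeline of processors
#                 'tag_vocab': tag_vocab,    # optional, read by POS.test
#                 'tag_proc': tag_proc},     # optional, read by CWS.test
#                'my_model.pkl')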


class POS(API):
    """FastNLP API for Part-Of-Speech tagging.

    :param str model_path: the path to the model.
    :param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch.
    """

    def __init__(self, model_path=None, device='cpu'):
        super(POS, self).__init__()
        if model_path is None:
            model_path = model_urls['pos']

        self.load(model_path, device)

    def predict(self, content):
        """Tag each token of the given sentences with its part of speech.

        :param content: list of list of str. Each string is a token (word).
        :return answer: list of list of str. Each string is a "word/tag" pair.
        """
        if not hasattr(self, "pipeline"):
            raise ValueError("You have to load model first.")

        sentence_list = content
        # 1. check the type of the input sentences
        for sentence in sentence_list:
            if not all(isinstance(obj, str) for obj in sentence):
                raise ValueError("Input must be list of list of string.")

        # 2. build the DataSet
        dataset = DataSet()
        dataset.add_field("words", sentence_list)

        # 3. run the pipeline
        self.pipeline(dataset)

        def merge_tag(words_list, tags_list):
            rtn = []
            for words, tags in zip(words_list, tags_list):
                rtn.append([w + "/" + t for w, t in zip(words, tags)])
            return rtn

        output = dataset.field_arrays["tag"].content
        if isinstance(content, str):
            return output[0]
        elif isinstance(content, list):
            return merge_tag(content, output)

    def test(self, file_path):
        test_data = ConllxDataLoader().load(file_path)

        save_dict = self._dict
        tag_vocab = save_dict["tag_vocab"]
        pipeline = save_dict["pipeline"]
        index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
        pipeline.pipeline = [index_tag] + pipeline.pipeline

        test_data.rename_field("pos_tags", "tag")
        pipeline(test_data)
        test_data.set_target("truth")
        prediction = test_data.field_arrays["predict"].content
        truth = test_data.field_arrays["truth"].content
        seq_len = test_data.field_arrays["word_seq_origin_len"].content

        # pad predictions and gold tags by hand so they stack into tensors
        max_length = max(len(seq) for seq in prediction)
        for idx in range(len(prediction)):
            prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx])))
            truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx])))
        evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth",
                                      seq_len="word_seq_origin_len")
        evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)},
                  {"truth": torch.Tensor(truth)})
        test_result = evaluator.get_metric()
        f1 = round(test_result['f'] * 100, 2)
        pre = round(test_result['pre'] * 100, 2)
        rec = round(test_result['rec'] * 100, 2)

        return {"F1": f1, "precision": pre, "recall": rec}


class CWS(API):
    """
    High-level interface for Chinese word segmentation.

    :param model_path: when model_path is None, the model at the default location is used;
        if it does not exist there, it is downloaded automatically.
    :param device: str, e.g. 'cpu', 'cuda' or 'cuda:0'. The model is loaded onto this device for inference.
    """

    def __init__(self, model_path=None, device='cpu'):
        super(CWS, self).__init__()
        if model_path is None:
            model_path = model_urls['cws']

        self.load(model_path, device)

    def predict(self, content):
        """
        Segmentation interface.

        :param content: str or List[str]. For example, for "中文分词很重要!" the result is
            "中文 分词 很 重要 !". If a List[str] such as ["中文分词很重要!", ...] is passed in,
            the result is ["中文 分词 很 重要 !", ...].
        :return: str or List[str], depending on the type of the input.
        """
        if not hasattr(self, 'pipeline'):
            raise ValueError("You have to load model first.")

        sentence_list = []
        # 1. check the type of the input
        if isinstance(content, str):
            sentence_list.append(content)
        elif isinstance(content, list):
            sentence_list = content

        # 2. build the DataSet
        dataset = DataSet()
        dataset.add_field('raw_sentence', sentence_list)

        # 3. run the pipeline
        self.pipeline(dataset)

        output = dataset.get_field('output').content
        if isinstance(content, str):
            return output[0]
        elif isinstance(content, list):
            return output

    def test(self, filepath):
        """
        Given the path to a segmentation data set, return segmentation f1, precision and recall
        on that data set. The file should look like::

            1 编者按 编者按 NN O 11 nmod:topic
            2 : : PU O 11 punct
            3 7月 7月 NT DATE 4 compound:nn
            4 12日 12日 NT DATE 11 nmod:tmod
            5 , , PU O 11 punct

            1 这 这 DT O 3 det
            2 款 款 M O 1 mark:clf
            3 飞行 飞行 NN O 8 nsubj
            4 从 从 P O 5 case
            5 外型 外型 NN O 8 nmod:prep

        Sentences are separated by blank lines; every non-blank line has 7 columns.

        :param filepath: str, path to the file.
        :return: dict with keys "F1", "precision" and "recall".
        """
        tag_proc = self._dict['tag_proc']
        cws_model = self.pipeline.pipeline[-2].model
        pipeline = self.pipeline.pipeline[:-2]

        pipeline.insert(1, tag_proc)
        pp = Pipeline(pipeline)

        reader = ConllCWSReader()

        te_dataset = reader.load(filepath)
        pp(te_dataset)

        from ..core.tester import Tester

        tester = Tester(data=te_dataset, model=cws_model, metrics=SpanFPreRecMetric(tag_proc.get_vocab()),
                        batch_size=64, verbose=0)
        eval_res = tester.test()

        f1 = eval_res['SpanFPreRecMetric']['f']
        pre = eval_res['SpanFPreRecMetric']['pre']
        rec = eval_res['SpanFPreRecMetric']['rec']

        return {"F1": f1, "precision": pre, "recall": rec}


class Parser(API):
    """FastNLP API for dependency parsing.

    :param str model_path: the path to the model.
    :param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch.
    """

    def __init__(self, model_path=None, device='cpu'):
        super(Parser, self).__init__()
        if model_path is None:
            model_path = model_urls['parser']

        self.pos_tagger = POS(device=device)
        self.load(model_path, device)

    def predict(self, content):
        if not hasattr(self, 'pipeline'):
            raise ValueError("You have to load model first.")

        # 1. use the POS tagger to obtain segmentation and POS tagging results
        pos_out = self.pos_tagger.predict(content)
        # e.g. pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()]

        # 2. build the DataSet
        dataset = DataSet()
        dataset.add_field('wp', pos_out)
        dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words')
        dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos')
        dataset.rename_field("words", "raw_words")

        # 3. run the pipeline
        self.pipeline(dataset)
        dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred')
        dataset.apply(lambda x: [arc + '/' + label for arc, label in
                                 zip(x['arc_pred'], x['label_pred_seq'])][1:], new_field_name='output')
        # output looks like: [['2/top', '0/root', '4/nn', '2/dep']]
        return dataset.field_arrays['output'].content

    def load_test_file(self, path):
        def get_one(sample):
            sample = list(map(list, zip(*sample)))
            if len(sample) == 0:
                return None
            # column 7 holds the dependency labels; '_' marks an incomplete sample
            for w in sample[7]:
                if w == '_':
                    print('Error Sample {}'.format(sample))
                    return None
            # return word_seq, pos_seq, head_seq, head_tag_seq
            return sample[1], sample[3], list(map(int, sample[6])), sample[7]

        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.split('\t'))
            if len(sample) > 0:
                datalist.append(sample)

        data = [get_one(sample) for sample in datalist]
        data_list = list(filter(lambda x: x is not None, data))
        return data_list

    def test(self, filepath):
        data = self.load_test_file(filepath)

        def convert(data):
            BOS = '<BOS>'
            dataset = DataSet()
            for sample in data:
                word_seq = [BOS] + sample[0]
                pos_seq = [BOS] + sample[1]
                heads = [0] + sample[2]
                head_tags = [BOS] + sample[3]
                dataset.append(Instance(raw_words=word_seq,
                                        pos=pos_seq,
                                        gold_heads=heads,
                                        arc_true=heads,
                                        tags=head_tags))
            return dataset

        ds = convert(data)
        pp = self.pipeline
        for p in pp:
            if p.field_name == 'word_list':
                p.field_name = 'gold_words'
            elif p.field_name == 'pos_list':
                p.field_name = 'gold_pos'
        pp(ds)
        head_cor, total = 0, 0
        for ins in ds:
            head_gold = ins['gold_heads']
            head_pred = ins['arc_pred']
            length = len(head_gold)
            total += length
            for i in range(length):
                head_cor += 1 if head_pred[i] == head_gold[i] else 0
        uas = head_cor / total

        # restore the original field names so the pipeline can be reused
        for p in pp:
            if p.field_name == 'gold_words':
                p.field_name = 'word_list'
            elif p.field_name == 'gold_pos':
                p.field_name = 'pos_list'

        return {"UAS": round(uas, 5)}


class Analyzer:
    def __init__(self, device='cpu'):
        self.cws = CWS(device=device)
        self.pos = POS(device=device)
        self.parser = Parser(device=device)

    def predict(self, content, seg=False, pos=False, parser=False):
        # default to segmentation when no task is selected
        if seg is False and pos is False and parser is False:
            seg = True
        output_dict = {}
        if seg:
            seg_output = self.cws.predict(content)
            output_dict['seg'] = seg_output
        if pos:
            pos_output = self.pos.predict(content)
            output_dict['pos'] = pos_output
        if parser:
            parser_output = self.parser.predict(content)
            output_dict['parser'] = parser_output

        return output_dict

    def test(self, filepath):
        output_dict = {}
        if self.cws:
            seg_output = self.cws.test(filepath)
            output_dict['seg'] = seg_output
        if self.pos:
            pos_output = self.pos.test(filepath)
            output_dict['pos'] = pos_output
        if self.parser:
            parser_output = self.parser.test(filepath)
            output_dict['parser'] = parser_output

        return output_dict
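

if __name__ == '__main__':
    # A short demo (an assumption, not part of the original file): requires the
    # pretrained models to be downloadable from the URLs in model_urls.
    analyzer = Analyzer(device='cpu')
    # CWS accepts a plain str; POS and Parser expect pre-tokenized input instead.
    print(analyzer.predict('中文分词很重要!', seg=True))
    # expected shape of the result: {'seg': '中文 分词 很 重要 !'}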