mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00
Update POS API
This commit is contained in:
parent
62ea4f7fed
commit
b14dd58828
@ -18,7 +18,7 @@ from fastNLP.api.processor import IndexerProcessor
|
||||
# TODO add pretrain urls
|
||||
model_urls = {
|
||||
"cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl",
|
||||
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190108-f3c60ee5.pkl",
|
||||
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
|
||||
"parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl"
|
||||
}
|
||||
|
||||
|
@ -16,6 +16,10 @@ def chinese_word_segmentation():
|
||||
|
||||
|
||||
def pos_tagging():
|
||||
# 输入已分词序列
|
||||
text = ['编者 按: 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。']
|
||||
text = [text[0].split()]
|
||||
print(text)
|
||||
pos = POS(device='cpu')
|
||||
print(pos.predict(text))
|
||||
|
||||
@ -26,4 +30,4 @@ def syntactic_parsing():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
syntactic_parsing()
|
||||
pos_tagging()
|
||||
|
@ -14,7 +14,7 @@ from fastNLP.core.metrics import SpanFPreRecMetric
|
||||
from fastNLP.core.trainer import Trainer
|
||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection
|
||||
from fastNLP.models.sequence_modeling import AdvSeqLabel
|
||||
from fastNLP.io.dataset_loader import ZhConllPOSReader, ConllxDataLoader
|
||||
from fastNLP.io.dataset_loader import ConllxDataLoader
|
||||
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ def load_tencent_embed(embed_path, word2id):
|
||||
return embedding_tensor
|
||||
|
||||
|
||||
def train(train_data_path, dev_data_path, checkpoint=None):
|
||||
def train(train_data_path, dev_data_path, checkpoint=None, save=None):
|
||||
# load config
|
||||
train_param = ConfigSection()
|
||||
model_param = ConfigSection()
|
||||
@ -44,9 +44,9 @@ def train(train_data_path, dev_data_path, checkpoint=None):
|
||||
|
||||
# Data Loader
|
||||
print("loading training set...")
|
||||
dataset = ConllxDataLoader().load(train_data_path)
|
||||
dataset = ConllxDataLoader().load(train_data_path, return_dataset=True)
|
||||
print("loading dev set...")
|
||||
dev_data = ConllxDataLoader().load(dev_data_path)
|
||||
dev_data = ConllxDataLoader().load(dev_data_path, return_dataset=True)
|
||||
print(dataset)
|
||||
print("================= dataset ready =====================")
|
||||
|
||||
@ -54,9 +54,9 @@ def train(train_data_path, dev_data_path, checkpoint=None):
|
||||
dev_data.rename_field("tag", "truth")
|
||||
|
||||
vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq")
|
||||
tag_proc = VocabIndexerProcessor("truth")
|
||||
tag_proc = VocabIndexerProcessor("truth", is_input=True)
|
||||
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True)
|
||||
set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len", "truth")
|
||||
set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len")
|
||||
|
||||
vocab_proc(dataset)
|
||||
tag_proc(dataset)
|
||||
@ -93,7 +93,7 @@ def train(train_data_path, dev_data_path, checkpoint=None):
|
||||
target="truth",
|
||||
seq_lens="word_seq_origin_len"),
|
||||
dev_data=dev_data, metric_key="f",
|
||||
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path="./save_0117")
|
||||
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save)
|
||||
trainer.train(load_best_model=True)
|
||||
|
||||
# save model & pipeline
|
||||
@ -102,12 +102,12 @@ def train(train_data_path, dev_data_path, checkpoint=None):
|
||||
|
||||
pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag])
|
||||
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
|
||||
torch.save(save_dict, "model_pp_0117.pkl")
|
||||
torch.save(save_dict, os.path.join(save, "model_pp.pkl"))
|
||||
print("pipeline saved")
|
||||
|
||||
|
||||
def run_test(test_path):
|
||||
test_data = ZhConllPOSReader().load(test_path)
|
||||
test_data = ConllxDataLoader().load(test_path, return_dataset=True)
|
||||
|
||||
with open("model_pp_0117.pkl", "rb") as f:
|
||||
save_dict = torch.load(f)
|
||||
@ -157,7 +157,7 @@ if __name__ == "__main__":
|
||||
# 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl
|
||||
if args.checkpoint is None:
|
||||
raise RuntimeError("Please provide the checkpoint. -cp ")
|
||||
train(args.train, args.dev, args.checkpoint)
|
||||
train(args.train, args.dev, args.checkpoint, save=args.save)
|
||||
else:
|
||||
# 一次训练 python train_pos_tag.py
|
||||
train(args.train, args.dev)
|
||||
train(args.train, args.dev, save=args.save)
|
||||
|
Loading…
Reference in New Issue
Block a user