update POS tag training script

This commit is contained in:
FengZiYjun 2019-01-07 21:50:52 +08:00
parent 1f4d784068
commit 525adf1c41
3 changed files with 65 additions and 9 deletions

View File

@ -10,7 +10,7 @@ eval_sort_key = 'accuracy'
[model]
rnn_hidden_units = 300
word_emb_dim = 300
word_emb_dim = 100
dropout = 0.5
use_crf = true
print_every_step = 10

View File

@ -1,4 +1,6 @@
import argparse
import os
import pickle
import sys
import torch
@ -21,7 +23,20 @@ cfgfile = './pos_tag.cfg'
pickle_path = "save"
def train():
def load_tencent_embed(embed_path, word2id, embed_dim=None):
    """Build an embedding matrix for a vocabulary from a pickled vector table.

    Rows whose key is present in the pickled table are filled with the stored
    vector; all other rows keep their random (standard normal) initialisation.

    :param embed_path: path to a pickle file holding a mapping from key to a
        sequence of floats (the vector for that key).
    :param word2id: mapping from key to the row index it occupies in the
        returned matrix.
    :param embed_dim: width of the embedding matrix. When ``None`` it is
        taken from the first vector in the table, falling back to 200 if the
        table is empty.
    :return: a ``torch`` tensor of shape ``(len(word2id), embed_dim)``.
    """
    # NOTE(security): pickle.load executes arbitrary code embedded in the
    # file -- only point this at embedding files you produced yourself.
    with open(embed_path, "rb") as f:
        embed_dict = pickle.load(f)

    if embed_dim is None:
        # Infer the width from the data instead of assuming 200, so tables
        # with a different vector size do not fail on row assignment.
        first_vector = next(iter(embed_dict.values()), None)
        embed_dim = len(first_vector) if first_vector is not None else 200

    vocab_size = len(word2id)
    embedding_tensor = torch.randn(vocab_size, embed_dim)

    hit = 0
    for key, index in word2id.items():
        vector = embed_dict.get(key)
        if vector is None:
            continue
        embedding_tensor[index] = torch.as_tensor(vector, dtype=embedding_tensor.dtype)
        hit += 1

    # Guard the ratio: an empty vocabulary would otherwise divide by zero.
    hit_ratio = hit / vocab_size if vocab_size else 0.0
    print("vocab_size={} hit={} hit/vocab_size={}".format(vocab_size, hit, hit_ratio))
    return embedding_tensor
def train(checkpoint=None):
# load config
train_param = ConfigSection()
model_param = ConfigSection()
@ -54,15 +69,21 @@ def train():
print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))
# define a model
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word)
if checkpoint is None:
# pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx)
pre_trained = None
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained)
print(model)
else:
model = torch.load(checkpoint)
# call trainer to train
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
target="truth",
seq_lens="word_seq_origin_len"),
dev_data=dataset, metric_key="f",
use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save")
trainer.train()
use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save")
trainer.train(load_best_model=True)
# save model & pipeline
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len")
@ -73,10 +94,20 @@ def train():
torch.save(save_dict, "model_pp.pkl")
print("pipeline saved")
def infer():
pass
torch.save(model, "./save/best_model.pkl")
if __name__ == "__main__":
    # Command-line entry point: either start a fresh training run or resume
    # from a previously saved model checkpoint.
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training")
    parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model")
    args = parser.parse_args()

    if args.restart:
        # Resume training: python train_pos_tag.py -c -cp ./save/best_model.pkl
        if args.checkpoint is None:
            raise RuntimeError("Please provide the checkpoint. -cp ")
        train(args.checkpoint)
    else:
        # Train from scratch: python train_pos_tag.py
        train()

View File

@ -0,0 +1,25 @@
import pickle
def load_embed(embed_path):
    """Read single-character embeddings from a space-separated text file.

    Each line is expected to be ``key v1 v2 ...``. Lines with five or fewer
    tokens (for example a ``count dim`` header) are skipped, and only keys
    that are exactly one character long are kept.

    :param embed_path: path to the UTF-8 text file to read.
    :return: dict mapping each kept key to its vector as a list of floats.
    """
    embed_dict = {}
    with open(embed_path, "r", encoding="utf-8") as f:
        for line in f:
            # Drop the line terminator and trailing blanks first; otherwise a
            # trailing space leaves a bare "\n" token that float() rejects.
            tokens = line.rstrip().split(" ")
            if len(tokens) <= 5:
                continue
            key = tokens[0]
            if len(key) != 1:
                continue
            # Ignore empty tokens produced by repeated separators.
            embed_dict[key] = [float(x) for x in tokens[1:] if x]
    return embed_dict
if __name__ == "__main__":
    # Extract the single-character vectors from the raw text embeddings and
    # store them as a pickle for later use.
    char_vectors = load_embed("/home/zyfeng/data/small.txt")
    print(char_vectors.keys())

    with open("./char_tencent_embedding.pkl", "wb") as out_file:
        pickle.dump(char_vectors, out_file)
    print("finished")