mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
update POS tag training script
This commit is contained in:
parent
1f4d784068
commit
525adf1c41
@ -10,7 +10,7 @@ eval_sort_key = 'accuracy'
|
||||
|
||||
[model]
|
||||
rnn_hidden_units = 300
|
||||
word_emb_dim = 300
|
||||
word_emb_dim = 100
|
||||
dropout = 0.5
|
||||
use_crf = true
|
||||
print_every_step = 10
|
||||
|
@ -1,4 +1,6 @@
|
||||
import argparse
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
import torch
|
||||
@ -21,7 +23,20 @@ cfgfile = './pos_tag.cfg'
|
||||
pickle_path = "save"
|
||||
|
||||
|
||||
def train():
|
||||
def load_tencent_embed(embed_path, word2id, embed_dim=200):
    """Build an embedding matrix for ``word2id`` from a pickled embedding dict.

    Every row starts out as random normal noise. Rows whose word is present
    in the pickled dictionary are overwritten with the stored vector. A
    summary of how many words were found is printed.

    :param embed_path: path to a pickle file; the unpickled object is used as
        a mapping from word to a sequence of floats.
    :param word2id: mapping from word to its row index in the returned tensor.
    :param embed_dim: width of each embedding row. Defaults to 200, the value
        that used to be hard-coded here.
    :return: a ``torch.Tensor`` of shape ``(len(word2id), embed_dim)``.
    """
    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file. Only point this at embedding files produced by a trusted source.
    with open(embed_path, "rb") as f:
        embed_dict = pickle.load(f)

    vocab_size = len(word2id)
    embedding_tensor = torch.randn(vocab_size, embed_dim)

    hit = 0
    for key in word2id:
        if key in embed_dict:
            embedding_tensor[word2id[key]] = torch.as_tensor(
                embed_dict[key], dtype=embedding_tensor.dtype
            )
            hit += 1

    # Guard the ratio so an empty vocabulary does not raise ZeroDivisionError.
    hit_ratio = hit / vocab_size if vocab_size else 0.0
    print("vocab_size={} hit={} hit/vocab_size={}".format(vocab_size, hit, hit_ratio))
    return embedding_tensor
|
||||
|
||||
|
||||
def train(checkpoint=None):
|
||||
# load config
|
||||
train_param = ConfigSection()
|
||||
model_param = ConfigSection()
|
||||
@ -54,15 +69,21 @@ def train():
|
||||
print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))
|
||||
|
||||
# define a model
|
||||
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word)
|
||||
if checkpoint is None:
|
||||
# pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx)
|
||||
pre_trained = None
|
||||
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained)
|
||||
print(model)
|
||||
else:
|
||||
model = torch.load(checkpoint)
|
||||
|
||||
# call trainer to train
|
||||
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
|
||||
target="truth",
|
||||
seq_lens="word_seq_origin_len"),
|
||||
dev_data=dataset, metric_key="f",
|
||||
use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save")
|
||||
trainer.train()
|
||||
use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save")
|
||||
trainer.train(load_best_model=True)
|
||||
|
||||
# save model & pipeline
|
||||
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len")
|
||||
@ -73,10 +94,20 @@ def train():
|
||||
torch.save(save_dict, "model_pp.pkl")
|
||||
print("pipeline saved")
|
||||
|
||||
|
||||
def infer():
|
||||
pass
|
||||
torch.save(model, "./save/best_model.pkl")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training")
|
||||
parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.restart is True:
|
||||
# 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl
|
||||
if args.checkpoint is None:
|
||||
raise RuntimeError("Please provide the checkpoint. -cp ")
|
||||
train(args.checkpoint)
|
||||
else:
|
||||
# 一次训练 python train_pos_tag.py
|
||||
train()
|
||||
|
25
reproduction/pos_tag_model/utils.py
Normal file
25
reproduction/pos_tag_model/utils.py
Normal file
@ -0,0 +1,25 @@
|
||||
import pickle
|
||||
|
||||
|
||||
def load_embed(embed_path):
    """Read single-character embeddings from a space-separated text file.

    Each kept line has the form ``<token> <v1> <v2> ...``. Lines with five or
    fewer space-separated fields are skipped (presumably this drops a short
    header line -- confirm against the embedding file format). Only tokens
    that are exactly one character long are kept.

    :param embed_path: path to a UTF-8 encoded text file.
    :return: dict mapping each single-character token to its list of floats.
    :raises ValueError: if a non-blank value field is not a valid float.
    """
    embed_dict = {}
    with open(embed_path, "r", encoding="utf-8") as f:
        for line in f:
            # Drop the line terminator first so it never ends up inside a
            # field; the field count of well-formed lines is unaffected.
            tokens = line.rstrip("\n").split(" ")
            if len(tokens) <= 5:
                continue
            key = tokens[0]
            if len(key) != 1:
                continue
            # A trailing (or doubled) space produces blank fields; skip them
            # instead of letting float() raise on an empty string.
            embed_dict[key] = [float(x) for x in tokens[1:] if x.strip()]
    return embed_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse the raw text embeddings into a {character: vector} table.
    char_embeddings = load_embed("/home/zyfeng/data/small.txt")

    print(char_embeddings.keys())

    # Persist the table as a pickle file in the current directory.
    with open("./char_tencent_embedding.pkl", "wb") as out_file:
        pickle.dump(char_embeddings, out_file)
    print("finished")
|
Loading…
Reference in New Issue
Block a user