mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 20:57:37 +08:00
update trainer
This commit is contained in:
parent
c14d9f4d66
commit
3192c9ac66
@ -24,6 +24,9 @@ class Field(object):
|
||||
def __repr__(self):
|
||||
return self.contents().__repr__()
|
||||
|
||||
def new(self, *args, **kwargs):
|
||||
return self.__class__(*args, **kwargs, is_target=self.is_target)
|
||||
|
||||
class TextField(Field):
|
||||
def __init__(self, text, is_target):
|
||||
"""
|
||||
|
@ -35,6 +35,9 @@ class Instance(object):
|
||||
else:
|
||||
raise KeyError("{} not found".format(name))
|
||||
|
||||
def __setitem__(self, name, field):
|
||||
return self.add_field(name, field)
|
||||
|
||||
def get_length(self):
|
||||
"""Fetch the length of all fields in the instance.
|
||||
|
||||
|
@ -74,7 +74,7 @@ class Tester(object):
|
||||
output_list = []
|
||||
truth_list = []
|
||||
|
||||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
|
||||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_x, batch_y in data_iterator:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from datetime import timedelta, datetime
|
||||
|
||||
import torch
|
||||
from tensorboardX import SummaryWriter
|
||||
@ -15,7 +15,7 @@ from fastNLP.saver.logger import create_logger
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
|
||||
logger = create_logger(__name__, "./train_test.log")
|
||||
|
||||
logger.disabled = True
|
||||
|
||||
class Trainer(object):
|
||||
"""Operations of training a model, including data loading, gradient descent, and validation.
|
||||
@ -42,7 +42,7 @@ class Trainer(object):
|
||||
"""
|
||||
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
|
||||
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
|
||||
"valid_step": 500, "eval_sort_key": None,
|
||||
"valid_step": 500, "eval_sort_key": 'acc',
|
||||
"loss": Loss(None), # used to pass type check
|
||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
|
||||
"evaluator": Evaluator()
|
||||
@ -111,13 +111,17 @@ class Trainer(object):
|
||||
else:
|
||||
self._model = network
|
||||
|
||||
print(self._model)
|
||||
|
||||
# define Tester over dev data
|
||||
self.dev_data = None
|
||||
if self.validate:
|
||||
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
|
||||
"use_cuda": self.use_cuda, "evaluator": self._evaluator}
|
||||
if self.validator is None:
|
||||
self.validator = self._create_validator(default_valid_args)
|
||||
logger.info("validator defined as {}".format(str(self.validator)))
|
||||
self.dev_data = dev_data
|
||||
|
||||
# optimizer and loss
|
||||
self.define_optimizer()
|
||||
@ -130,7 +134,7 @@ class Trainer(object):
|
||||
|
||||
# main training procedure
|
||||
start = time.time()
|
||||
self.start_time = str(start)
|
||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M'))
|
||||
|
||||
logger.info("training epochs started " + self.start_time)
|
||||
epoch, iters = 1, 0
|
||||
@ -141,15 +145,17 @@ class Trainer(object):
|
||||
|
||||
# prepare mini-batch iterator
|
||||
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
|
||||
use_cuda=self.use_cuda)
|
||||
use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
|
||||
logger.info("prepared data iterator")
|
||||
|
||||
# one forward and backward pass
|
||||
iters += self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
|
||||
iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
|
||||
|
||||
# validation
|
||||
if self.validate:
|
||||
self.valid_model()
|
||||
self.save_model(self._model, 'training_model_'+self.start_time)
|
||||
epoch += 1
|
||||
|
||||
def _train_step(self, data_iterator, network, **kwargs):
|
||||
"""Training process in one epoch.
|
||||
@ -160,13 +166,16 @@ class Trainer(object):
|
||||
- epoch: int,
|
||||
"""
|
||||
step = kwargs['step']
|
||||
dev_data = kwargs['dev_data']
|
||||
for batch_x, batch_y in data_iterator:
|
||||
|
||||
prediction = self.data_forward(network, batch_x)
|
||||
|
||||
loss = self.get_loss(prediction, batch_y)
|
||||
self.grad_backward(loss)
|
||||
if torch.rand(1).item() < 0.001:
|
||||
print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step))
|
||||
for name, p in self._model.named_parameters():
|
||||
if p.requires_grad:
|
||||
print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item()))
|
||||
self.update()
|
||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
|
||||
|
||||
@ -183,13 +192,14 @@ class Trainer(object):
|
||||
return step
|
||||
|
||||
def valid_model(self):
|
||||
if dev_data is None:
|
||||
if self.dev_data is None:
|
||||
raise RuntimeError(
|
||||
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
|
||||
logger.info("validation started")
|
||||
res = self.validator.test(network, dev_data)
|
||||
res = self.validator.test(self._model, self.dev_data)
|
||||
if self.save_best_dev and self.best_eval_result(res):
|
||||
logger.info('save best result! {}'.format(res))
|
||||
print('save best result! {}'.format(res))
|
||||
self.save_model(self._model, 'best_model_'+self.start_time)
|
||||
return res
|
||||
|
||||
@ -282,14 +292,10 @@ class Trainer(object):
|
||||
"""
|
||||
if isinstance(metrics, tuple):
|
||||
loss, metrics = metrics
|
||||
else:
|
||||
metrics = validator.metrics
|
||||
|
||||
if isinstance(metrics, dict):
|
||||
if len(metrics) == 1:
|
||||
accuracy = list(metrics.values())[0]
|
||||
elif self.eval_sort_key is None:
|
||||
raise ValueError('dict format metrics should provide sort key for eval best result')
|
||||
else:
|
||||
accuracy = metrics[self.eval_sort_key]
|
||||
else:
|
||||
|
@ -199,6 +199,8 @@ class BiaffineParser(GraphParser):
|
||||
word_emb_dim,
|
||||
pos_vocab_size,
|
||||
pos_emb_dim,
|
||||
word_hid_dim,
|
||||
pos_hid_dim,
|
||||
rnn_layers,
|
||||
rnn_hidden_size,
|
||||
arc_mlp_size,
|
||||
@ -209,10 +211,15 @@ class BiaffineParser(GraphParser):
|
||||
use_greedy_infer=False):
|
||||
|
||||
super(BiaffineParser, self).__init__()
|
||||
rnn_out_size = 2 * rnn_hidden_size
|
||||
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
|
||||
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
|
||||
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
|
||||
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
|
||||
self.word_norm = nn.LayerNorm(word_hid_dim)
|
||||
self.pos_norm = nn.LayerNorm(pos_hid_dim)
|
||||
if use_var_lstm:
|
||||
self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim,
|
||||
self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
|
||||
hidden_size=rnn_hidden_size,
|
||||
num_layers=rnn_layers,
|
||||
bias=True,
|
||||
@ -221,7 +228,7 @@ class BiaffineParser(GraphParser):
|
||||
hidden_dropout=dropout,
|
||||
bidirectional=True)
|
||||
else:
|
||||
self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim,
|
||||
self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
|
||||
hidden_size=rnn_hidden_size,
|
||||
num_layers=rnn_layers,
|
||||
bias=True,
|
||||
@ -229,12 +236,13 @@ class BiaffineParser(GraphParser):
|
||||
dropout=dropout,
|
||||
bidirectional=True)
|
||||
|
||||
rnn_out_size = 2 * rnn_hidden_size
|
||||
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
|
||||
nn.LayerNorm(arc_mlp_size),
|
||||
nn.ELU(),
|
||||
TimestepDropout(p=dropout),)
|
||||
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
|
||||
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
|
||||
nn.LayerNorm(label_mlp_size),
|
||||
nn.ELU(),
|
||||
TimestepDropout(p=dropout),)
|
||||
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp)
|
||||
@ -242,10 +250,18 @@ class BiaffineParser(GraphParser):
|
||||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
|
||||
self.normal_dropout = nn.Dropout(p=dropout)
|
||||
self.use_greedy_infer = use_greedy_infer
|
||||
initial_parameter(self)
|
||||
self.word_norm = nn.LayerNorm(word_emb_dim)
|
||||
self.pos_norm = nn.LayerNorm(pos_emb_dim)
|
||||
self.lstm_norm = nn.LayerNorm(rnn_out_size)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Embedding):
|
||||
continue
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
else:
|
||||
for p in m.parameters():
|
||||
nn.init.normal_(p, 0, 0.01)
|
||||
|
||||
def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
|
||||
"""
|
||||
@ -262,19 +278,21 @@ class BiaffineParser(GraphParser):
|
||||
# prepare embeddings
|
||||
batch_size, seq_len = word_seq.shape
|
||||
# print('forward {} {}'.format(batch_size, seq_len))
|
||||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
|
||||
|
||||
# get sequence mask
|
||||
seq_mask = len2masks(word_seq_origin_len, seq_len).long()
|
||||
|
||||
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0]
|
||||
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1]
|
||||
word, pos = self.word_fc(word), self.pos_fc(pos)
|
||||
word, pos = self.word_norm(word), self.pos_norm(pos)
|
||||
x = torch.cat([word, pos], dim=2) # -> [N,L,C]
|
||||
del word, pos
|
||||
|
||||
# lstm, extract features
|
||||
x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True)
|
||||
feat, _ = self.lstm(x) # -> [N,L,C]
|
||||
feat = self.lstm_norm(feat)
|
||||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
|
||||
|
||||
# for arc biaffine
|
||||
# mlp, reduce dim
|
||||
@ -282,6 +300,7 @@ class BiaffineParser(GraphParser):
|
||||
arc_head = self.arc_head_mlp(feat)
|
||||
label_dep = self.label_dep_mlp(feat)
|
||||
label_head = self.label_head_mlp(feat)
|
||||
del feat
|
||||
|
||||
# biaffine arc classifier
|
||||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
|
||||
@ -289,7 +308,7 @@ class BiaffineParser(GraphParser):
|
||||
arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
|
||||
|
||||
# use gold or predicted arc to predict label
|
||||
if gold_heads is None:
|
||||
if gold_heads is None or not self.training:
|
||||
# use greedy decoding in training
|
||||
if self.training or self.use_greedy_infer:
|
||||
heads = self._greedy_decoder(arc_pred, seq_mask)
|
||||
@ -301,6 +320,7 @@ class BiaffineParser(GraphParser):
|
||||
head_pred = None
|
||||
heads = gold_heads
|
||||
|
||||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
|
||||
label_head = label_head[batch_range, heads].contiguous()
|
||||
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
|
||||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask}
|
||||
|
@ -1,16 +1,14 @@
|
||||
[train]
|
||||
epochs = -1
|
||||
<<<<<<< HEAD
|
||||
batch_size = 16
|
||||
=======
|
||||
batch_size = 32
|
||||
>>>>>>> update biaffine
|
||||
pickle_path = "./save/"
|
||||
validate = true
|
||||
save_best_dev = true
|
||||
eval_sort_key = "UAS"
|
||||
use_cuda = true
|
||||
model_saved_path = "./save/"
|
||||
print_every_step = 20
|
||||
use_golden_train=true
|
||||
|
||||
[test]
|
||||
save_output = true
|
||||
@ -26,14 +24,17 @@ word_vocab_size = -1
|
||||
word_emb_dim = 100
|
||||
pos_vocab_size = -1
|
||||
pos_emb_dim = 100
|
||||
word_hid_dim = 100
|
||||
pos_hid_dim = 100
|
||||
rnn_layers = 3
|
||||
rnn_hidden_size = 400
|
||||
arc_mlp_size = 500
|
||||
label_mlp_size = 100
|
||||
num_label = -1
|
||||
dropout = 0.33
|
||||
use_var_lstm=true
|
||||
use_var_lstm=false
|
||||
use_greedy_infer=false
|
||||
|
||||
[optim]
|
||||
lr = 2e-3
|
||||
weight_decay = 0.0
|
||||
|
@ -6,6 +6,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
|
||||
from collections import defaultdict
|
||||
import math
|
||||
import torch
|
||||
import re
|
||||
|
||||
from fastNLP.core.trainer import Trainer
|
||||
from fastNLP.core.metrics import Evaluator
|
||||
@ -55,10 +56,10 @@ class ConlluDataLoader(object):
|
||||
return ds
|
||||
|
||||
def get_one(self, sample):
|
||||
text = ['<root>']
|
||||
pos_tags = ['<root>']
|
||||
heads = [0]
|
||||
head_tags = ['root']
|
||||
text = []
|
||||
pos_tags = []
|
||||
heads = []
|
||||
head_tags = []
|
||||
for w in sample:
|
||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
|
||||
if t3 == '_':
|
||||
@ -96,12 +97,13 @@ class CTBDataLoader(object):
|
||||
def convert(self, data):
|
||||
dataset = DataSet()
|
||||
for sample in data:
|
||||
word_seq = ["<ROOT>"] + sample[0]
|
||||
pos_seq = ["<ROOT>"] + sample[1]
|
||||
heads = [0] + list(map(int, sample[2]))
|
||||
head_tags = ["ROOT"] + sample[3]
|
||||
word_seq = ["<s>"] + sample[0] + ['</s>']
|
||||
pos_seq = ["<s>"] + sample[1] + ['</s>']
|
||||
heads = [0] + list(map(int, sample[2])) + [0]
|
||||
head_tags = ["<s>"] + sample[3] + ['</s>']
|
||||
dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
|
||||
pos_seq=TextField(pos_seq, is_target=False),
|
||||
gold_heads=SeqLabelField(heads, is_target=False),
|
||||
head_indices=SeqLabelField(heads, is_target=True),
|
||||
head_labels=TextField(head_tags, is_target=True)))
|
||||
return dataset
|
||||
@ -117,7 +119,8 @@ datadir = '/home/yfshao/workdir/parser-data/'
|
||||
train_data_name = "train_ctb5.txt"
|
||||
dev_data_name = "dev_ctb5.txt"
|
||||
test_data_name = "test_ctb5.txt"
|
||||
emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt"
|
||||
emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt"
|
||||
# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec"
|
||||
loader = CTBDataLoader()
|
||||
|
||||
cfgfile = './cfg.cfg'
|
||||
@ -129,6 +132,10 @@ test_args = ConfigSection()
|
||||
model_args = ConfigSection()
|
||||
optim_args = ConfigSection()
|
||||
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})
|
||||
print('trainre Args:', train_args.data)
|
||||
print('test Args:', test_args.data)
|
||||
print('optim Args:', optim_args.data)
|
||||
|
||||
|
||||
# Pickle Loader
|
||||
def save_data(dirpath, **kwargs):
|
||||
@ -151,9 +158,31 @@ def load_data(dirpath):
|
||||
datas[name] = _pickle.load(f)
|
||||
return datas
|
||||
|
||||
def P2(data, field, length):
|
||||
ds = [ins for ins in data if ins[field].get_length() >= length]
|
||||
data.clear()
|
||||
data.extend(ds)
|
||||
return ds
|
||||
|
||||
def P1(data, field):
|
||||
def reeng(w):
|
||||
return w if w == '<s>' or w == '</s>' or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else 'ENG'
|
||||
def renum(w):
|
||||
return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else 'NUMBER'
|
||||
for ins in data:
|
||||
ori = ins[field].contents()
|
||||
s = list(map(renum, map(reeng, ori)))
|
||||
if s != ori:
|
||||
# print(ori)
|
||||
# print(s)
|
||||
# print()
|
||||
ins[field] = ins[field].new(s)
|
||||
return data
|
||||
|
||||
class ParserEvaluator(Evaluator):
|
||||
def __init__(self):
|
||||
def __init__(self, ignore_label):
|
||||
super(ParserEvaluator, self).__init__()
|
||||
self.ignore = ignore_label
|
||||
|
||||
def __call__(self, predict_list, truth_list):
|
||||
head_all, label_all, total_all = 0, 0, 0
|
||||
@ -174,6 +203,7 @@ class ParserEvaluator(Evaluator):
|
||||
label_pred_correct: number of correct predicted labels.
|
||||
total_tokens: number of predicted tokens
|
||||
"""
|
||||
seq_mask *= (head_labels != self.ignore).long()
|
||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask
|
||||
_, label_preds = torch.max(label_pred, dim=2)
|
||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
|
||||
@ -181,72 +211,93 @@ class ParserEvaluator(Evaluator):
|
||||
|
||||
try:
|
||||
data_dict = load_data(processed_datadir)
|
||||
word_v = data_dict['word_v']
|
||||
pos_v = data_dict['pos_v']
|
||||
tag_v = data_dict['tag_v']
|
||||
train_data = data_dict['train_data']
|
||||
dev_data = data_dict['dev_data']
|
||||
test_data = data_dict['test_datas']
|
||||
print('use saved pickles')
|
||||
|
||||
except Exception as _:
|
||||
print('load raw data and preprocess')
|
||||
word_v = Vocabulary(need_default=True, min_freq=2)
|
||||
# use pretrain embedding
|
||||
pos_v = Vocabulary(need_default=True)
|
||||
tag_v = Vocabulary(need_default=False)
|
||||
train_data = loader.load(os.path.join(datadir, train_data_name))
|
||||
dev_data = loader.load(os.path.join(datadir, dev_data_name))
|
||||
test_data = loader.load(os.path.join(datadir, test_data_name))
|
||||
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
|
||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)
|
||||
train_data.update_vocab(pos_seq=pos_v, head_labels=tag_v)
|
||||
save_data(processed_datadir, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)
|
||||
|
||||
train_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
|
||||
dev_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
|
||||
train_data.set_origin_len("word_seq")
|
||||
dev_data.set_origin_len("word_seq")
|
||||
embed, word_v = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', None, os.path.join(processed_datadir, 'word_emb.pkl'))
|
||||
word_v.unknown_label = "<OOV>"
|
||||
|
||||
print(train_data[:3])
|
||||
print(len(train_data))
|
||||
print(len(dev_data))
|
||||
# Model
|
||||
model_args['word_vocab_size'] = len(word_v)
|
||||
model_args['pos_vocab_size'] = len(pos_v)
|
||||
model_args['num_label'] = len(tag_v)
|
||||
|
||||
model = BiaffineParser(**model_args.data)
|
||||
model.reset_parameters()
|
||||
|
||||
def train():
|
||||
datasets = (train_data, dev_data, test_data)
|
||||
for ds in datasets:
|
||||
# print('====='*30)
|
||||
P1(ds, 'word_seq')
|
||||
P2(ds, 'word_seq', 5)
|
||||
ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
|
||||
ds.set_origin_len('word_seq')
|
||||
if train_args['use_golden_train']:
|
||||
ds.set_target(gold_heads=False)
|
||||
else:
|
||||
ds.set_target(gold_heads=None)
|
||||
train_args.data.pop('use_golden_train')
|
||||
ignore_label = pos_v['P']
|
||||
|
||||
print(test_data[0])
|
||||
print(len(train_data))
|
||||
print(len(dev_data))
|
||||
print(len(test_data))
|
||||
|
||||
|
||||
|
||||
def train(path):
|
||||
# Trainer
|
||||
trainer = Trainer(**train_args.data)
|
||||
|
||||
def _define_optim(obj):
|
||||
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data)
|
||||
lr = optim_args.data['lr']
|
||||
embed_params = set(obj._model.word_embedding.parameters())
|
||||
decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters())
|
||||
params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params]
|
||||
obj._optimizer = torch.optim.Adam([
|
||||
{'params': list(embed_params), 'lr':lr*0.1},
|
||||
{'params': list(decay_params), **optim_args.data},
|
||||
{'params': params}
|
||||
], lr=lr)
|
||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))
|
||||
|
||||
def _update(obj):
|
||||
# torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0)
|
||||
obj._scheduler.step()
|
||||
obj._optimizer.step()
|
||||
|
||||
trainer.define_optimizer = lambda: _define_optim(trainer)
|
||||
trainer.update = lambda: _update(trainer)
|
||||
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator()))
|
||||
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)))
|
||||
|
||||
# Model
|
||||
model = BiaffineParser(**model_args.data)
|
||||
|
||||
# use pretrain embedding
|
||||
word_v.unknown_label = "<OOV>"
|
||||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
|
||||
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
|
||||
|
||||
model.word_embedding.padding_idx = word_v.padding_idx
|
||||
model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
|
||||
model.pos_embedding.padding_idx = pos_v.padding_idx
|
||||
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)
|
||||
|
||||
try:
|
||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
|
||||
print('model parameter loaded!')
|
||||
except Exception as _:
|
||||
print("No saved model. Continue.")
|
||||
pass
|
||||
# try:
|
||||
# ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
|
||||
# print('model parameter loaded!')
|
||||
# except Exception as _:
|
||||
# print("No saved model. Continue.")
|
||||
# pass
|
||||
|
||||
# Start training
|
||||
trainer.train(model, train_data, dev_data)
|
||||
@ -258,15 +309,15 @@ def train():
|
||||
print("Model saved!")
|
||||
|
||||
|
||||
def test():
|
||||
def test(path):
|
||||
# Tester
|
||||
tester = Tester(**test_args.data, evaluator=ParserEvaluator())
|
||||
tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))
|
||||
|
||||
# Model
|
||||
model = BiaffineParser(**model_args.data)
|
||||
|
||||
try:
|
||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
|
||||
ModelLoader.load_pytorch(model, path)
|
||||
print('model parameter loaded!')
|
||||
except Exception as _:
|
||||
print("No saved model. Abort test.")
|
||||
@ -284,11 +335,12 @@ if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
|
||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer'])
|
||||
parser.add_argument('--path', type=str, default='')
|
||||
args = parser.parse_args()
|
||||
if args.mode == 'train':
|
||||
train()
|
||||
train(args.path)
|
||||
elif args.mode == 'test':
|
||||
test()
|
||||
test(args.path)
|
||||
elif args.mode == 'infer':
|
||||
infer()
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user