update trainer

This commit is contained in:
yunfan 2018-11-04 17:57:35 +08:00
parent c14d9f4d66
commit 3192c9ac66
7 changed files with 157 additions and 72 deletions

View File

@ -24,6 +24,9 @@ class Field(object):
def __repr__(self):
return self.contents().__repr__()
def new(self, *args, **kwargs):
return self.__class__(*args, **kwargs, is_target=self.is_target)
class TextField(Field):
def __init__(self, text, is_target):
"""

View File

@ -35,6 +35,9 @@ class Instance(object):
else:
raise KeyError("{} not found".format(name))
def __setitem__(self, name, field):
return self.add_field(name, field)
def get_length(self):
"""Fetch the length of all fields in the instance.

View File

@ -74,7 +74,7 @@ class Tester(object):
output_list = []
truth_list = []
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
with torch.no_grad():
for batch_x, batch_y in data_iterator:

View File

@ -1,6 +1,6 @@
import os
import time
from datetime import timedelta
from datetime import timedelta, datetime
import torch
from tensorboardX import SummaryWriter
@ -15,7 +15,7 @@ from fastNLP.saver.logger import create_logger
from fastNLP.saver.model_saver import ModelSaver
logger = create_logger(__name__, "./train_test.log")
logger.disabled = True
class Trainer(object):
"""Operations of training a model, including data loading, gradient descent, and validation.
@ -42,7 +42,7 @@ class Trainer(object):
"""
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
"valid_step": 500, "eval_sort_key": None,
"valid_step": 500, "eval_sort_key": 'acc',
"loss": Loss(None), # used to pass type check
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
"evaluator": Evaluator()
@ -111,13 +111,17 @@ class Trainer(object):
else:
self._model = network
print(self._model)
# define Tester over dev data
self.dev_data = None
if self.validate:
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
"use_cuda": self.use_cuda, "evaluator": self._evaluator}
if self.validator is None:
self.validator = self._create_validator(default_valid_args)
logger.info("validator defined as {}".format(str(self.validator)))
self.dev_data = dev_data
# optimizer and loss
self.define_optimizer()
@ -130,7 +134,7 @@ class Trainer(object):
# main training procedure
start = time.time()
self.start_time = str(start)
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M'))
logger.info("training epochs started " + self.start_time)
epoch, iters = 1, 0
@ -141,15 +145,17 @@ class Trainer(object):
# prepare mini-batch iterator
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
use_cuda=self.use_cuda)
use_cuda=self.use_cuda, sort_in_batch=True, sort_key='word_seq')
logger.info("prepared data iterator")
# one forward and backward pass
iters += self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
iters = self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
# validation
if self.validate:
self.valid_model()
self.save_model(self._model, 'training_model_'+self.start_time)
epoch += 1
def _train_step(self, data_iterator, network, **kwargs):
"""Training process in one epoch.
@ -160,13 +166,16 @@ class Trainer(object):
- epoch: int,
"""
step = kwargs['step']
dev_data = kwargs['dev_data']
for batch_x, batch_y in data_iterator:
prediction = self.data_forward(network, batch_x)
loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
if torch.rand(1).item() < 0.001:
print('[grads at epoch: {:>3} step: {:>4}]'.format(kwargs['epoch'], step))
for name, p in self._model.named_parameters():
if p.requires_grad:
print('\t{} {} {}'.format(name, tuple(p.size()), torch.sum(p.grad).item()))
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
@ -183,13 +192,14 @@ class Trainer(object):
return step
def valid_model(self):
if dev_data is None:
if self.dev_data is None:
raise RuntimeError(
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
logger.info("validation started")
res = self.validator.test(network, dev_data)
res = self.validator.test(self._model, self.dev_data)
if self.save_best_dev and self.best_eval_result(res):
logger.info('save best result! {}'.format(res))
print('save best result! {}'.format(res))
self.save_model(self._model, 'best_model_'+self.start_time)
return res
@ -282,14 +292,10 @@ class Trainer(object):
"""
if isinstance(metrics, tuple):
loss, metrics = metrics
else:
metrics = validator.metrics
if isinstance(metrics, dict):
if len(metrics) == 1:
accuracy = list(metrics.values())[0]
elif self.eval_sort_key is None:
raise ValueError('dict format metrics should provide sort key for eval best result')
else:
accuracy = metrics[self.eval_sort_key]
else:

View File

@ -199,6 +199,8 @@ class BiaffineParser(GraphParser):
word_emb_dim,
pos_vocab_size,
pos_emb_dim,
word_hid_dim,
pos_hid_dim,
rnn_layers,
rnn_hidden_size,
arc_mlp_size,
@ -209,10 +211,15 @@ class BiaffineParser(GraphParser):
use_greedy_infer=False):
super(BiaffineParser, self).__init__()
rnn_out_size = 2 * rnn_hidden_size
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
self.word_norm = nn.LayerNorm(word_hid_dim)
self.pos_norm = nn.LayerNorm(pos_hid_dim)
if use_var_lstm:
self.lstm = VarLSTM(input_size=word_emb_dim + pos_emb_dim,
self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
@ -221,7 +228,7 @@ class BiaffineParser(GraphParser):
hidden_dropout=dropout,
bidirectional=True)
else:
self.lstm = nn.LSTM(input_size=word_emb_dim + pos_emb_dim,
self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
@ -229,12 +236,13 @@ class BiaffineParser(GraphParser):
dropout=dropout,
bidirectional=True)
rnn_out_size = 2 * rnn_hidden_size
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size),
nn.LayerNorm(arc_mlp_size),
nn.ELU(),
TimestepDropout(p=dropout),)
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp)
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size),
nn.LayerNorm(label_mlp_size),
nn.ELU(),
TimestepDropout(p=dropout),)
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp)
@ -242,10 +250,18 @@ class BiaffineParser(GraphParser):
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
self.normal_dropout = nn.Dropout(p=dropout)
self.use_greedy_infer = use_greedy_infer
initial_parameter(self)
self.word_norm = nn.LayerNorm(word_emb_dim)
self.pos_norm = nn.LayerNorm(pos_emb_dim)
self.lstm_norm = nn.LayerNorm(rnn_out_size)
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Embedding):
continue
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
else:
for p in m.parameters():
nn.init.normal_(p, 0, 0.01)
def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
"""
@ -262,19 +278,21 @@ class BiaffineParser(GraphParser):
# prepare embeddings
batch_size, seq_len = word_seq.shape
# print('forward {} {}'.format(batch_size, seq_len))
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
# get sequence mask
seq_mask = len2masks(word_seq_origin_len, seq_len).long()
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0]
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1]
word, pos = self.word_fc(word), self.pos_fc(pos)
word, pos = self.word_norm(word), self.pos_norm(pos)
x = torch.cat([word, pos], dim=2) # -> [N,L,C]
del word, pos
# lstm, extract features
x = nn.utils.rnn.pack_padded_sequence(x, word_seq_origin_len.squeeze(1), batch_first=True)
feat, _ = self.lstm(x) # -> [N,L,C]
feat = self.lstm_norm(feat)
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
# for arc biaffine
# mlp, reduce dim
@ -282,6 +300,7 @@ class BiaffineParser(GraphParser):
arc_head = self.arc_head_mlp(feat)
label_dep = self.label_dep_mlp(feat)
label_head = self.label_head_mlp(feat)
del feat
# biaffine arc classifier
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]
@ -289,7 +308,7 @@ class BiaffineParser(GraphParser):
arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
# use gold or predicted arc to predict label
if gold_heads is None:
if gold_heads is None or not self.training:
# use greedy decoding in training
if self.training or self.use_greedy_infer:
heads = self._greedy_decoder(arc_pred, seq_mask)
@ -301,6 +320,7 @@ class BiaffineParser(GraphParser):
head_pred = None
heads = gold_heads
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
label_head = label_head[batch_range, heads].contiguous()
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'seq_mask': seq_mask}

View File

@ -1,16 +1,14 @@
[train]
epochs = -1
<<<<<<< HEAD
batch_size = 16
=======
batch_size = 32
>>>>>>> update biaffine
pickle_path = "./save/"
validate = true
save_best_dev = true
eval_sort_key = "UAS"
use_cuda = true
model_saved_path = "./save/"
print_every_step = 20
use_golden_train=true
[test]
save_output = true
@ -26,14 +24,17 @@ word_vocab_size = -1
word_emb_dim = 100
pos_vocab_size = -1
pos_emb_dim = 100
word_hid_dim = 100
pos_hid_dim = 100
rnn_layers = 3
rnn_hidden_size = 400
arc_mlp_size = 500
label_mlp_size = 100
num_label = -1
dropout = 0.33
use_var_lstm=true
use_var_lstm=false
use_greedy_infer=false
[optim]
lr = 2e-3
weight_decay = 0.0

View File

@ -6,6 +6,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from collections import defaultdict
import math
import torch
import re
from fastNLP.core.trainer import Trainer
from fastNLP.core.metrics import Evaluator
@ -55,10 +56,10 @@ class ConlluDataLoader(object):
return ds
def get_one(self, sample):
text = ['<root>']
pos_tags = ['<root>']
heads = [0]
head_tags = ['root']
text = []
pos_tags = []
heads = []
head_tags = []
for w in sample:
t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
if t3 == '_':
@ -96,12 +97,13 @@ class CTBDataLoader(object):
def convert(self, data):
dataset = DataSet()
for sample in data:
word_seq = ["<ROOT>"] + sample[0]
pos_seq = ["<ROOT>"] + sample[1]
heads = [0] + list(map(int, sample[2]))
head_tags = ["ROOT"] + sample[3]
word_seq = ["<s>"] + sample[0] + ['</s>']
pos_seq = ["<s>"] + sample[1] + ['</s>']
heads = [0] + list(map(int, sample[2])) + [0]
head_tags = ["<s>"] + sample[3] + ['</s>']
dataset.append(Instance(word_seq=TextField(word_seq, is_target=False),
pos_seq=TextField(pos_seq, is_target=False),
gold_heads=SeqLabelField(heads, is_target=False),
head_indices=SeqLabelField(heads, is_target=True),
head_labels=TextField(head_tags, is_target=True)))
return dataset
@ -117,7 +119,8 @@ datadir = '/home/yfshao/workdir/parser-data/'
train_data_name = "train_ctb5.txt"
dev_data_name = "dev_ctb5.txt"
test_data_name = "test_ctb5.txt"
emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt"
emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt"
# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec"
loader = CTBDataLoader()
cfgfile = './cfg.cfg'
@ -129,6 +132,10 @@ test_args = ConfigSection()
model_args = ConfigSection()
optim_args = ConfigSection()
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args})
print('trainre Args:', train_args.data)
print('test Args:', test_args.data)
print('optim Args:', optim_args.data)
# Pickle Loader
def save_data(dirpath, **kwargs):
@ -151,9 +158,31 @@ def load_data(dirpath):
datas[name] = _pickle.load(f)
return datas
def P2(data, field, length):
ds = [ins for ins in data if ins[field].get_length() >= length]
data.clear()
data.extend(ds)
return ds
def P1(data, field):
def reeng(w):
return w if w == '<s>' or w == '</s>' or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else 'ENG'
def renum(w):
return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else 'NUMBER'
for ins in data:
ori = ins[field].contents()
s = list(map(renum, map(reeng, ori)))
if s != ori:
# print(ori)
# print(s)
# print()
ins[field] = ins[field].new(s)
return data
class ParserEvaluator(Evaluator):
def __init__(self):
def __init__(self, ignore_label):
super(ParserEvaluator, self).__init__()
self.ignore = ignore_label
def __call__(self, predict_list, truth_list):
head_all, label_all, total_all = 0, 0, 0
@ -174,6 +203,7 @@ class ParserEvaluator(Evaluator):
label_pred_correct: number of correct predicted labels.
total_tokens: number of predicted tokens
"""
seq_mask *= (head_labels != self.ignore).long()
head_pred_correct = (head_pred == head_indices).long() * seq_mask
_, label_preds = torch.max(label_pred, dim=2)
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
@ -181,72 +211,93 @@ class ParserEvaluator(Evaluator):
try:
data_dict = load_data(processed_datadir)
word_v = data_dict['word_v']
pos_v = data_dict['pos_v']
tag_v = data_dict['tag_v']
train_data = data_dict['train_data']
dev_data = data_dict['dev_data']
test_data = data_dict['test_datas']
print('use saved pickles')
except Exception as _:
print('load raw data and preprocess')
word_v = Vocabulary(need_default=True, min_freq=2)
# use pretrain embedding
pos_v = Vocabulary(need_default=True)
tag_v = Vocabulary(need_default=False)
train_data = loader.load(os.path.join(datadir, train_data_name))
dev_data = loader.load(os.path.join(datadir, dev_data_name))
test_data = loader.load(os.path.join(datadir, test_data_name))
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)
train_data.update_vocab(pos_seq=pos_v, head_labels=tag_v)
save_data(processed_datadir, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data)
train_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
dev_data.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
train_data.set_origin_len("word_seq")
dev_data.set_origin_len("word_seq")
embed, word_v = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', None, os.path.join(processed_datadir, 'word_emb.pkl'))
word_v.unknown_label = "<OOV>"
print(train_data[:3])
print(len(train_data))
print(len(dev_data))
# Model
model_args['word_vocab_size'] = len(word_v)
model_args['pos_vocab_size'] = len(pos_v)
model_args['num_label'] = len(tag_v)
model = BiaffineParser(**model_args.data)
model.reset_parameters()
def train():
datasets = (train_data, dev_data, test_data)
for ds in datasets:
# print('====='*30)
P1(ds, 'word_seq')
P2(ds, 'word_seq', 5)
ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v)
ds.set_origin_len('word_seq')
if train_args['use_golden_train']:
ds.set_target(gold_heads=False)
else:
ds.set_target(gold_heads=None)
train_args.data.pop('use_golden_train')
ignore_label = pos_v['P']
print(test_data[0])
print(len(train_data))
print(len(dev_data))
print(len(test_data))
def train(path):
# Trainer
trainer = Trainer(**train_args.data)
def _define_optim(obj):
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data)
lr = optim_args.data['lr']
embed_params = set(obj._model.word_embedding.parameters())
decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters())
params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params]
obj._optimizer = torch.optim.Adam([
{'params': list(embed_params), 'lr':lr*0.1},
{'params': list(decay_params), **optim_args.data},
{'params': params}
], lr=lr)
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))
def _update(obj):
# torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0)
obj._scheduler.step()
obj._optimizer.step()
trainer.define_optimizer = lambda: _define_optim(trainer)
trainer.update = lambda: _update(trainer)
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator()))
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)))
# Model
model = BiaffineParser(**model_args.data)
# use pretrain embedding
word_v.unknown_label = "<OOV>"
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
model.word_embedding.padding_idx = word_v.padding_idx
model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
model.pos_embedding.padding_idx = pos_v.padding_idx
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0)
try:
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print('model parameter loaded!')
except Exception as _:
print("No saved model. Continue.")
pass
# try:
# ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
# print('model parameter loaded!')
# except Exception as _:
# print("No saved model. Continue.")
# pass
# Start training
trainer.train(model, train_data, dev_data)
@ -258,15 +309,15 @@ def train():
print("Model saved!")
def test():
def test(path):
# Tester
tester = Tester(**test_args.data, evaluator=ParserEvaluator())
tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))
# Model
model = BiaffineParser(**model_args.data)
try:
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
ModelLoader.load_pytorch(model, path)
print('model parameter loaded!')
except Exception as _:
print("No saved model. Abort test.")
@ -284,11 +335,12 @@ if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer'])
parser.add_argument('--path', type=str, default='')
args = parser.parse_args()
if args.mode == 'train':
train()
train(args.path)
elif args.mode == 'test':
test()
test(args.path)
elif args.mode == 'infer':
infer()
else: