update biaffine

This commit is contained in:
yunfan 2018-10-31 10:53:23 +08:00
parent 96a2794fdf
commit c14d9f4d66
5 changed files with 136 additions and 113 deletions

View File

@ -17,9 +17,9 @@ class Tester(object):
"""
super(Tester, self).__init__()
"""
"default_args" provides default value for important settings.
The initialization arguments "kwargs" with the same key (name) will override the default value.
"kwargs" must have the same type as "default_args" on corresponding keys.
"default_args" provides default value for important settings.
The initialization arguments "kwargs" with the same key (name) will override the default value.
"kwargs" must have the same type as "default_args" on corresponding keys.
Otherwise, error will raise.
"""
default_args = {"batch_size": 8,
@ -29,8 +29,8 @@ class Tester(object):
"evaluator": Evaluator()
}
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
"""
required_args = {}
@ -76,14 +76,17 @@ class Tester(object):
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
for batch_x, batch_y in data_iterator:
with torch.no_grad():
with torch.no_grad():
for batch_x, batch_y in data_iterator:
prediction = self.data_forward(network, batch_x)
output_list.append(prediction)
truth_list.append(batch_y)
eval_results = self.evaluate(output_list, truth_list)
output_list.append(prediction)
truth_list.append(batch_y)
eval_results = self.evaluate(output_list, truth_list)
print("[tester] {}".format(self.print_eval_results(eval_results)))
logger.info("[tester] {}".format(self.print_eval_results(eval_results)))
self.mode(network, is_test=False)
self.metrics = eval_results
return eval_results
def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.

View File

@ -35,20 +35,21 @@ class Trainer(object):
super(Trainer, self).__init__()
"""
"default_args" provides default value for important settings.
The initialization arguments "kwargs" with the same key (name) will override the default value.
"kwargs" must have the same type as "default_args" on corresponding keys.
"default_args" provides default value for important settings.
The initialization arguments "kwargs" with the same key (name) will override the default value.
"kwargs" must have the same type as "default_args" on corresponding keys.
Otherwise, error will raise.
"""
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
"valid_step": 500, "eval_sort_key": None,
"loss": Loss(None), # used to pass type check
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
"evaluator": Evaluator()
}
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
"""
required_args = {}
@ -70,16 +71,20 @@ class Trainer(object):
else:
# Trainer doesn't care about extra arguments
pass
print(default_args)
print("Training Args {}".format(default_args))
logger.info("Training Args {}".format(default_args))
self.n_epochs = default_args["epochs"]
self.batch_size = default_args["batch_size"]
self.n_epochs = int(default_args["epochs"])
self.batch_size = int(default_args["batch_size"])
self.pickle_path = default_args["pickle_path"]
self.validate = default_args["validate"]
self.save_best_dev = default_args["save_best_dev"]
self.use_cuda = default_args["use_cuda"]
self.model_name = default_args["model_name"]
self.print_every_step = default_args["print_every_step"]
self.print_every_step = int(default_args["print_every_step"])
self.valid_step = int(default_args["valid_step"])
if self.validate is not None:
assert self.valid_step > 0
self._model = None
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
@ -89,6 +94,8 @@ class Trainer(object):
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
self._graph_summaried = False
self._best_accuracy = 0.0
self.eval_sort_key = default_args['eval_sort_key']
self.validator = None
def train(self, network, train_data, dev_data=None):
"""General Training Procedure
@ -108,8 +115,9 @@ class Trainer(object):
if self.validate:
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
"use_cuda": self.use_cuda, "evaluator": self._evaluator}
validator = self._create_validator(default_valid_args)
logger.info("validator defined as {}".format(str(validator)))
if self.validator is None:
self.validator = self._create_validator(default_valid_args)
logger.info("validator defined as {}".format(str(self.validator)))
# optimizer and loss
self.define_optimizer()
@ -117,29 +125,31 @@ class Trainer(object):
self.define_loss()
logger.info("loss function defined as {}".format(str(self._loss_func)))
# turn on network training mode
self.mode(network, is_test=False)
# main training procedure
start = time.time()
logger.info("training epochs started")
for epoch in range(1, self.n_epochs + 1):
self.start_time = str(start)
logger.info("training epochs started " + self.start_time)
epoch, iters = 1, 0
while(1):
if self.n_epochs != -1 and epoch > self.n_epochs:
break
logger.info("training epoch {}".format(epoch))
# turn on network training mode
self.mode(network, is_test=False)
# prepare mini-batch iterator
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
use_cuda=self.use_cuda)
logger.info("prepared data iterator")
# one forward and backward pass
self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch)
iters += self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
# validation
if self.validate:
if dev_data is None:
raise RuntimeError(
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
logger.info("validation started")
validator.test(network, dev_data)
self.valid_model()
def _train_step(self, data_iterator, network, **kwargs):
"""Training process in one epoch.
@ -149,7 +159,8 @@ class Trainer(object):
- start: time.time(), the starting time of this step.
- epoch: int,
"""
step = 0
step = kwargs['step']
dev_data = kwargs['dev_data']
for batch_x, batch_y in data_iterator:
prediction = self.data_forward(network, batch_x)
@ -166,7 +177,21 @@ class Trainer(object):
kwargs["epoch"], step, loss.data, diff)
print(print_output)
logger.info(print_output)
if self.validate and self.valid_step > 0 and step > 0 and step % self.valid_step == 0:
self.valid_model()
step += 1
return step
def valid_model(self):
if dev_data is None:
raise RuntimeError(
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
logger.info("validation started")
res = self.validator.test(network, dev_data)
if self.save_best_dev and self.best_eval_result(res):
logger.info('save best result! {}'.format(res))
self.save_model(self._model, 'best_model_'+self.start_time)
return res
def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
@ -180,11 +205,17 @@ class Trainer(object):
else:
model.train()
def define_optimizer(self):
def define_optimizer(self, optim=None):
"""Define framework-specific optimizer specified by the models.
"""
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters())
if optim is not None:
# optimizer constructed by user
self._optimizer = optim
elif self._optimizer is None:
# optimizer constructed by proto
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters())
return self._optimizer
def update(self):
"""Perform weight update on a model.
@ -217,6 +248,8 @@ class Trainer(object):
:param truth: ground truth label vector
:return: a scalar
"""
if isinstance(predict, dict) and isinstance(truth, dict):
return self._loss_func(**predict, **truth)
if len(truth) > 1:
raise NotImplementedError("Not ready to handle multi-labels.")
truth = list(truth.values())[0] if len(truth) > 0 else None
@ -241,13 +274,27 @@ class Trainer(object):
raise ValueError("Please specify a loss function.")
logger.info("The model didn't define loss, use Trainer's loss.")
def best_eval_result(self, validator):
def best_eval_result(self, metrics):
"""Check if the current epoch yields better validation results.
:param validator: a Tester instance
:return: bool, True means current results on dev set is the best.
"""
loss, accuracy = validator.metrics
if isinstance(metrics, tuple):
loss, metrics = metrics
else:
metrics = validator.metrics
if isinstance(metrics, dict):
if len(metrics) == 1:
accuracy = list(metrics.values())[0]
elif self.eval_sort_key is None:
raise ValueError('dict format metrics should provide sort key for eval best result')
else:
accuracy = metrics[self.eval_sort_key]
else:
accuracy = metrics
if accuracy > self._best_accuracy:
self._best_accuracy = accuracy
return True
@ -268,6 +315,8 @@ class Trainer(object):
def _create_validator(self, valid_args):
raise NotImplementedError
def set_validator(self, validor):
self.validator = validor
class SeqLabelTrainer(Trainer):
"""Trainer for Sequence Labeling

View File

@ -243,6 +243,9 @@ class BiaffineParser(GraphParser):
self.normal_dropout = nn.Dropout(p=dropout)
self.use_greedy_infer = use_greedy_infer
initial_parameter(self)
self.word_norm = nn.LayerNorm(word_emb_dim)
self.pos_norm = nn.LayerNorm(pos_emb_dim)
self.lstm_norm = nn.LayerNorm(rnn_out_size)
def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
"""
@ -266,10 +269,12 @@ class BiaffineParser(GraphParser):
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0]
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1]
word, pos = self.word_norm(word), self.pos_norm(pos)
x = torch.cat([word, pos], dim=2) # -> [N,L,C]
# lstm, extract features
feat, _ = self.lstm(x) # -> [N,L,C]
feat = self.lstm_norm(feat)
# for arc biaffine
# mlp, reduce dim
@ -292,6 +297,7 @@ class BiaffineParser(GraphParser):
heads = self._mst_decoder(arc_pred, seq_mask)
head_pred = heads
else:
assert self.training # must be training mode
head_pred = None
heads = gold_heads
@ -331,40 +337,4 @@ class BiaffineParser(GraphParser):
label_nll = -(label_loss*float_mask).sum() / length
return arc_nll + label_nll
def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs):
"""
Evaluate the performance of prediction.
:return dict: performance results.
head_pred_corrct: number of correct predicted heads.
label_pred_correct: number of correct predicted labels.
total_tokens: number of predicted tokens
"""
if 'head_pred' in kwargs:
head_pred = kwargs['head_pred']
elif self.use_greedy_infer:
head_pred = self._greedy_decoder(arc_pred, seq_mask)
else:
head_pred = self._mst_decoder(arc_pred, seq_mask)
head_pred_correct = (head_pred == head_indices).long() * seq_mask
_, label_preds = torch.max(label_pred, dim=2)
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
return {"head_pred_correct": head_pred_correct.sum(dim=1),
"label_pred_correct": label_pred_correct.sum(dim=1),
"total_tokens": seq_mask.sum(dim=1)}
def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_):
"""
Compute the metrics of model
:param head_pred_corrct: number of correct predicted heads.
:param label_pred_correct: number of correct predicted labels.
:param total_tokens: number of predicted tokens
:return dict: the metrics results
UAS: the head predicted accuracy
LAS: the label predicted accuracy
"""
return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100,
"LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100}

View File

@ -1,23 +1,25 @@
[train]
epochs = -1
<<<<<<< HEAD
batch_size = 16
=======
batch_size = 32
>>>>>>> update biaffine
pickle_path = "./save/"
validate = true
save_best_dev = false
save_best_dev = true
eval_sort_key = "UAS"
use_cuda = true
model_saved_path = "./save/"
task = "parse"
[test]
save_output = true
validate_in_training = true
save_dev_input = false
save_loss = true
batch_size = 16
batch_size = 64
pickle_path = "./save/"
use_cuda = true
task = "parse"
[model]
word_vocab_size = -1

View File

@ -8,12 +8,14 @@ import math
import torch
from fastNLP.core.trainer import Trainer
from fastNLP.core.metrics import Evaluator
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.field import TextField, SeqLabelField
from fastNLP.core.preprocess import load_pickle
from fastNLP.core.tester import Tester
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
@ -111,9 +113,10 @@ class CTBDataLoader(object):
# emb_file_name = '/home/yfshao/glove.6B.100d.txt'
# loader = ConlluDataLoader()
datadir = "/home/yfshao/parser-data"
datadir = '/home/yfshao/workdir/parser-data/'
train_data_name = "train_ctb5.txt"
dev_data_name = "dev_ctb5.txt"
test_data_name = "test_ctb5.txt"
emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt"
loader = CTBDataLoader()
@ -148,37 +151,33 @@ def load_data(dirpath):
datas[name] = _pickle.load(f)
return datas
class MyTester(object):
def __init__(self, batch_size, use_cuda=False, **kwagrs):
self.batch_size = batch_size
self.use_cuda = use_cuda
class ParserEvaluator(Evaluator):
def __init__(self):
super(ParserEvaluator, self).__init__()
def test(self, model, dataset):
self.model = model.cuda() if self.use_cuda else model
self.model.eval()
batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda)
eval_res = defaultdict(list)
i = 0
for batch_x, batch_y in batchiter:
with torch.no_grad():
pred_y = self.model(**batch_x)
eval_one = self.model.evaluate(**pred_y, **batch_y)
i += self.batch_size
for eval_name, tensor in eval_one.items():
eval_res[eval_name].append(tensor)
tmp = {}
for eval_name, tensorlist in eval_res.items():
tmp[eval_name] = torch.cat(tensorlist, dim=0)
def __call__(self, predict_list, truth_list):
head_all, label_all, total_all = 0, 0, 0
for pred, truth in zip(predict_list, truth_list):
head, label, total = self.evaluate(**pred, **truth)
head_all += head
label_all += label
total_all += total
self.res = self.model.metrics(**tmp)
print(self.show_metrics())
return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all}
def show_metrics(self):
s = ""
for name, val in self.res.items():
s += '{}: {:.2f}\t'.format(name, val)
return s
def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_):
"""
Evaluate the performance of prediction.
:return : performance results.
head_pred_corrct: number of correct predicted heads.
label_pred_correct: number of correct predicted labels.
total_tokens: number of predicted tokens
"""
head_pred_correct = (head_pred == head_indices).long() * seq_mask
_, label_preds = torch.max(label_pred, dim=2)
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item()
try:
data_dict = load_data(processed_datadir)
@ -196,6 +195,7 @@ except Exception as _:
tag_v = Vocabulary(need_default=False)
train_data = loader.load(os.path.join(datadir, train_data_name))
dev_data = loader.load(os.path.join(datadir, dev_data_name))
test_data = loader.load(os.path.join(datadir, test_data_name))
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)
@ -207,8 +207,6 @@ dev_data.set_origin_len("word_seq")
print(train_data[:3])
print(len(train_data))
print(len(dev_data))
ep = train_args['epochs']
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep
model_args['word_vocab_size'] = len(word_v)
model_args['pos_vocab_size'] = len(pos_v)
model_args['num_label'] = len(tag_v)
@ -220,7 +218,7 @@ def train():
def _define_optim(obj):
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data)
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4))
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))
def _update(obj):
obj._scheduler.step()
@ -228,8 +226,7 @@ def train():
trainer.define_optimizer = lambda: _define_optim(trainer)
trainer.update = lambda: _update(trainer)
trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth)
trainer._create_validator = lambda x: MyTester(**test_args.data)
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator()))
# Model
model = BiaffineParser(**model_args.data)
@ -238,6 +235,7 @@ def train():
word_v.unknown_label = "<OOV>"
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
model.word_embedding.padding_idx = word_v.padding_idx
model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
model.pos_embedding.padding_idx = pos_v.padding_idx
@ -262,7 +260,7 @@ def train():
def test():
# Tester
tester = MyTester(**test_args.data)
tester = Tester(**test_args.data, evaluator=ParserEvaluator())
# Model
model = BiaffineParser(**model_args.data)
@ -275,9 +273,10 @@ def test():
raise
# Start training
print("Testing Dev data")
tester.test(model, dev_data)
print(tester.show_metrics())
print("Testing finished!")
print("Testing Test data")
tester.test(model, test_data)