mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 20:57:37 +08:00
update biaffine
This commit is contained in:
parent
96a2794fdf
commit
c14d9f4d66
@ -17,9 +17,9 @@ class Tester(object):
|
||||
"""
|
||||
super(Tester, self).__init__()
|
||||
"""
|
||||
"default_args" provides default value for important settings.
|
||||
The initialization arguments "kwargs" with the same key (name) will override the default value.
|
||||
"kwargs" must have the same type as "default_args" on corresponding keys.
|
||||
"default_args" provides default value for important settings.
|
||||
The initialization arguments "kwargs" with the same key (name) will override the default value.
|
||||
"kwargs" must have the same type as "default_args" on corresponding keys.
|
||||
Otherwise, an error will be raised.
|
||||
"""
|
||||
default_args = {"batch_size": 8,
|
||||
@ -29,8 +29,8 @@ class Tester(object):
|
||||
"evaluator": Evaluator()
|
||||
}
|
||||
"""
|
||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
|
||||
This is used to warn users of essential settings in the training.
|
||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
|
||||
This is used to warn users of essential settings in the training.
|
||||
Notably, "required_args" have no default values, so they are independent of "default_args".
|
||||
"""
|
||||
required_args = {}
|
||||
@ -76,14 +76,17 @@ class Tester(object):
|
||||
|
||||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
|
||||
|
||||
for batch_x, batch_y in data_iterator:
|
||||
with torch.no_grad():
|
||||
with torch.no_grad():
|
||||
for batch_x, batch_y in data_iterator:
|
||||
prediction = self.data_forward(network, batch_x)
|
||||
output_list.append(prediction)
|
||||
truth_list.append(batch_y)
|
||||
eval_results = self.evaluate(output_list, truth_list)
|
||||
output_list.append(prediction)
|
||||
truth_list.append(batch_y)
|
||||
eval_results = self.evaluate(output_list, truth_list)
|
||||
print("[tester] {}".format(self.print_eval_results(eval_results)))
|
||||
logger.info("[tester] {}".format(self.print_eval_results(eval_results)))
|
||||
self.mode(network, is_test=False)
|
||||
self.metrics = eval_results
|
||||
return eval_results
|
||||
|
||||
def mode(self, model, is_test=False):
|
||||
"""Train mode or Test mode. This is for PyTorch currently.
|
||||
|
@ -35,20 +35,21 @@ class Trainer(object):
|
||||
super(Trainer, self).__init__()
|
||||
|
||||
"""
|
||||
"default_args" provides default value for important settings.
|
||||
The initialization arguments "kwargs" with the same key (name) will override the default value.
|
||||
"kwargs" must have the same type as "default_args" on corresponding keys.
|
||||
"default_args" provides default value for important settings.
|
||||
The initialization arguments "kwargs" with the same key (name) will override the default value.
|
||||
"kwargs" must have the same type as "default_args" on corresponding keys.
|
||||
Otherwise, an error will be raised.
|
||||
"""
|
||||
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
|
||||
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
|
||||
"valid_step": 500, "eval_sort_key": None,
|
||||
"loss": Loss(None), # used to pass type check
|
||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
|
||||
"evaluator": Evaluator()
|
||||
}
|
||||
"""
|
||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
|
||||
This is used to warn users of essential settings in the training.
|
||||
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
|
||||
This is used to warn users of essential settings in the training.
|
||||
Notably, "required_args" have no default values, so they are independent of "default_args".
|
||||
"""
|
||||
required_args = {}
|
||||
@ -70,16 +71,20 @@ class Trainer(object):
|
||||
else:
|
||||
# Trainer doesn't care about extra arguments
|
||||
pass
|
||||
print(default_args)
|
||||
print("Training Args {}".format(default_args))
|
||||
logger.info("Training Args {}".format(default_args))
|
||||
|
||||
self.n_epochs = default_args["epochs"]
|
||||
self.batch_size = default_args["batch_size"]
|
||||
self.n_epochs = int(default_args["epochs"])
|
||||
self.batch_size = int(default_args["batch_size"])
|
||||
self.pickle_path = default_args["pickle_path"]
|
||||
self.validate = default_args["validate"]
|
||||
self.save_best_dev = default_args["save_best_dev"]
|
||||
self.use_cuda = default_args["use_cuda"]
|
||||
self.model_name = default_args["model_name"]
|
||||
self.print_every_step = default_args["print_every_step"]
|
||||
self.print_every_step = int(default_args["print_every_step"])
|
||||
self.valid_step = int(default_args["valid_step"])
|
||||
if self.validate is not None:
|
||||
assert self.valid_step > 0
|
||||
|
||||
self._model = None
|
||||
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
|
||||
@ -89,6 +94,8 @@ class Trainer(object):
|
||||
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
|
||||
self._graph_summaried = False
|
||||
self._best_accuracy = 0.0
|
||||
self.eval_sort_key = default_args['eval_sort_key']
|
||||
self.validator = None
|
||||
|
||||
def train(self, network, train_data, dev_data=None):
|
||||
"""General Training Procedure
|
||||
@ -108,8 +115,9 @@ class Trainer(object):
|
||||
if self.validate:
|
||||
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
|
||||
"use_cuda": self.use_cuda, "evaluator": self._evaluator}
|
||||
validator = self._create_validator(default_valid_args)
|
||||
logger.info("validator defined as {}".format(str(validator)))
|
||||
if self.validator is None:
|
||||
self.validator = self._create_validator(default_valid_args)
|
||||
logger.info("validator defined as {}".format(str(self.validator)))
|
||||
|
||||
# optimizer and loss
|
||||
self.define_optimizer()
|
||||
@ -117,29 +125,31 @@ class Trainer(object):
|
||||
self.define_loss()
|
||||
logger.info("loss function defined as {}".format(str(self._loss_func)))
|
||||
|
||||
# turn on network training mode
|
||||
self.mode(network, is_test=False)
|
||||
|
||||
# main training procedure
|
||||
start = time.time()
|
||||
logger.info("training epochs started")
|
||||
for epoch in range(1, self.n_epochs + 1):
|
||||
self.start_time = str(start)
|
||||
|
||||
logger.info("training epochs started " + self.start_time)
|
||||
epoch, iters = 1, 0
|
||||
while(1):
|
||||
if self.n_epochs != -1 and epoch > self.n_epochs:
|
||||
break
|
||||
logger.info("training epoch {}".format(epoch))
|
||||
|
||||
# turn on network training mode
|
||||
self.mode(network, is_test=False)
|
||||
# prepare mini-batch iterator
|
||||
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
|
||||
use_cuda=self.use_cuda)
|
||||
logger.info("prepared data iterator")
|
||||
|
||||
# one forward and backward pass
|
||||
self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch)
|
||||
iters += self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch, step=iters, dev_data=dev_data)
|
||||
|
||||
# validation
|
||||
if self.validate:
|
||||
if dev_data is None:
|
||||
raise RuntimeError(
|
||||
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
|
||||
logger.info("validation started")
|
||||
validator.test(network, dev_data)
|
||||
self.valid_model()
|
||||
|
||||
def _train_step(self, data_iterator, network, **kwargs):
|
||||
"""Training process in one epoch.
|
||||
@ -149,7 +159,8 @@ class Trainer(object):
|
||||
- start: time.time(), the starting time of this step.
|
||||
- epoch: int,
|
||||
"""
|
||||
step = 0
|
||||
step = kwargs['step']
|
||||
dev_data = kwargs['dev_data']
|
||||
for batch_x, batch_y in data_iterator:
|
||||
|
||||
prediction = self.data_forward(network, batch_x)
|
||||
@ -166,7 +177,21 @@ class Trainer(object):
|
||||
kwargs["epoch"], step, loss.data, diff)
|
||||
print(print_output)
|
||||
logger.info(print_output)
|
||||
if self.validate and self.valid_step > 0 and step > 0 and step % self.valid_step == 0:
|
||||
self.valid_model()
|
||||
step += 1
|
||||
return step
|
||||
|
||||
def valid_model(self):
|
||||
if dev_data is None:
|
||||
raise RuntimeError(
|
||||
"self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
|
||||
logger.info("validation started")
|
||||
res = self.validator.test(network, dev_data)
|
||||
if self.save_best_dev and self.best_eval_result(res):
|
||||
logger.info('save best result! {}'.format(res))
|
||||
self.save_model(self._model, 'best_model_'+self.start_time)
|
||||
return res
|
||||
|
||||
def mode(self, model, is_test=False):
|
||||
"""Train mode or Test mode. This is for PyTorch currently.
|
||||
@ -180,11 +205,17 @@ class Trainer(object):
|
||||
else:
|
||||
model.train()
|
||||
|
||||
def define_optimizer(self):
|
||||
def define_optimizer(self, optim=None):
|
||||
"""Define framework-specific optimizer specified by the models.
|
||||
|
||||
"""
|
||||
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters())
|
||||
if optim is not None:
|
||||
# optimizer constructed by user
|
||||
self._optimizer = optim
|
||||
elif self._optimizer is None:
|
||||
# optimizer constructed by proto
|
||||
self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters())
|
||||
return self._optimizer
|
||||
|
||||
def update(self):
|
||||
"""Perform weight update on a model.
|
||||
@ -217,6 +248,8 @@ class Trainer(object):
|
||||
:param truth: ground truth label vector
|
||||
:return: a scalar
|
||||
"""
|
||||
if isinstance(predict, dict) and isinstance(truth, dict):
|
||||
return self._loss_func(**predict, **truth)
|
||||
if len(truth) > 1:
|
||||
raise NotImplementedError("Not ready to handle multi-labels.")
|
||||
truth = list(truth.values())[0] if len(truth) > 0 else None
|
||||
@ -241,13 +274,27 @@ class Trainer(object):
|
||||
raise ValueError("Please specify a loss function.")
|
||||
logger.info("The model didn't define loss, use Trainer's loss.")
|
||||
|
||||
def best_eval_result(self, validator):
|
||||
def best_eval_result(self, metrics):
|
||||
"""Check if the current epoch yields better validation results.
|
||||
|
||||
:param validator: a Tester instance
|
||||
:return: bool, True means current results on dev set is the best.
|
||||
"""
|
||||
loss, accuracy = validator.metrics
|
||||
if isinstance(metrics, tuple):
|
||||
loss, metrics = metrics
|
||||
else:
|
||||
metrics = validator.metrics
|
||||
|
||||
if isinstance(metrics, dict):
|
||||
if len(metrics) == 1:
|
||||
accuracy = list(metrics.values())[0]
|
||||
elif self.eval_sort_key is None:
|
||||
raise ValueError('dict format metrics should provide sort key for eval best result')
|
||||
else:
|
||||
accuracy = metrics[self.eval_sort_key]
|
||||
else:
|
||||
accuracy = metrics
|
||||
|
||||
if accuracy > self._best_accuracy:
|
||||
self._best_accuracy = accuracy
|
||||
return True
|
||||
@ -268,6 +315,8 @@ class Trainer(object):
|
||||
def _create_validator(self, valid_args):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_validator(self, validor):
|
||||
self.validator = validor
|
||||
|
||||
class SeqLabelTrainer(Trainer):
|
||||
"""Trainer for Sequence Labeling
|
||||
|
@ -243,6 +243,9 @@ class BiaffineParser(GraphParser):
|
||||
self.normal_dropout = nn.Dropout(p=dropout)
|
||||
self.use_greedy_infer = use_greedy_infer
|
||||
initial_parameter(self)
|
||||
self.word_norm = nn.LayerNorm(word_emb_dim)
|
||||
self.pos_norm = nn.LayerNorm(pos_emb_dim)
|
||||
self.lstm_norm = nn.LayerNorm(rnn_out_size)
|
||||
|
||||
def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
|
||||
"""
|
||||
@ -266,10 +269,12 @@ class BiaffineParser(GraphParser):
|
||||
|
||||
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0]
|
||||
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1]
|
||||
word, pos = self.word_norm(word), self.pos_norm(pos)
|
||||
x = torch.cat([word, pos], dim=2) # -> [N,L,C]
|
||||
|
||||
# lstm, extract features
|
||||
feat, _ = self.lstm(x) # -> [N,L,C]
|
||||
feat = self.lstm_norm(feat)
|
||||
|
||||
# for arc biaffine
|
||||
# mlp, reduce dim
|
||||
@ -292,6 +297,7 @@ class BiaffineParser(GraphParser):
|
||||
heads = self._mst_decoder(arc_pred, seq_mask)
|
||||
head_pred = heads
|
||||
else:
|
||||
assert self.training # must be training mode
|
||||
head_pred = None
|
||||
heads = gold_heads
|
||||
|
||||
@ -331,40 +337,4 @@ class BiaffineParser(GraphParser):
|
||||
label_nll = -(label_loss*float_mask).sum() / length
|
||||
return arc_nll + label_nll
|
||||
|
||||
def evaluate(self, arc_pred, label_pred, head_indices, head_labels, seq_mask, **kwargs):
|
||||
"""
|
||||
Evaluate the performance of prediction.
|
||||
|
||||
:return dict: performance results.
|
||||
head_pred_correct: number of correctly predicted heads.
|
||||
label_pred_correct: number of correct predicted labels.
|
||||
total_tokens: number of predicted tokens
|
||||
"""
|
||||
if 'head_pred' in kwargs:
|
||||
head_pred = kwargs['head_pred']
|
||||
elif self.use_greedy_infer:
|
||||
head_pred = self._greedy_decoder(arc_pred, seq_mask)
|
||||
else:
|
||||
head_pred = self._mst_decoder(arc_pred, seq_mask)
|
||||
|
||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask
|
||||
_, label_preds = torch.max(label_pred, dim=2)
|
||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
|
||||
return {"head_pred_correct": head_pred_correct.sum(dim=1),
|
||||
"label_pred_correct": label_pred_correct.sum(dim=1),
|
||||
"total_tokens": seq_mask.sum(dim=1)}
|
||||
|
||||
def metrics(self, head_pred_correct, label_pred_correct, total_tokens, **_):
|
||||
"""
|
||||
Compute the metrics of model
|
||||
|
||||
:param head_pred_correct: number of correctly predicted heads.
|
||||
:param label_pred_correct: number of correct predicted labels.
|
||||
:param total_tokens: number of predicted tokens
|
||||
:return dict: the metrics results
|
||||
UAS: the head predicted accuracy
|
||||
LAS: the label predicted accuracy
|
||||
"""
|
||||
return {"UAS": head_pred_correct.sum().float() / total_tokens.sum().float() * 100,
|
||||
"LAS": label_pred_correct.sum().float() / total_tokens.sum().float() * 100}
|
||||
|
||||
|
@ -1,23 +1,25 @@
|
||||
[train]
|
||||
epochs = -1
|
||||
<<<<<<< HEAD
|
||||
batch_size = 16
|
||||
=======
|
||||
batch_size = 32
|
||||
>>>>>>> update biaffine
|
||||
pickle_path = "./save/"
|
||||
validate = true
|
||||
save_best_dev = false
|
||||
save_best_dev = true
|
||||
eval_sort_key = "UAS"
|
||||
use_cuda = true
|
||||
model_saved_path = "./save/"
|
||||
task = "parse"
|
||||
|
||||
|
||||
[test]
|
||||
save_output = true
|
||||
validate_in_training = true
|
||||
save_dev_input = false
|
||||
save_loss = true
|
||||
batch_size = 16
|
||||
batch_size = 64
|
||||
pickle_path = "./save/"
|
||||
use_cuda = true
|
||||
task = "parse"
|
||||
|
||||
[model]
|
||||
word_vocab_size = -1
|
||||
|
@ -8,12 +8,14 @@ import math
|
||||
import torch
|
||||
|
||||
from fastNLP.core.trainer import Trainer
|
||||
from fastNLP.core.metrics import Evaluator
|
||||
from fastNLP.core.instance import Instance
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.sampler import SequentialSampler
|
||||
from fastNLP.core.field import TextField, SeqLabelField
|
||||
from fastNLP.core.preprocess import load_pickle
|
||||
from fastNLP.core.tester import Tester
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
@ -111,9 +113,10 @@ class CTBDataLoader(object):
|
||||
# emb_file_name = '/home/yfshao/glove.6B.100d.txt'
|
||||
# loader = ConlluDataLoader()
|
||||
|
||||
datadir = "/home/yfshao/parser-data"
|
||||
datadir = '/home/yfshao/workdir/parser-data/'
|
||||
train_data_name = "train_ctb5.txt"
|
||||
dev_data_name = "dev_ctb5.txt"
|
||||
test_data_name = "test_ctb5.txt"
|
||||
emb_file_name = "/home/yfshao/parser-data/word_OOVthr_30_100v.txt"
|
||||
loader = CTBDataLoader()
|
||||
|
||||
@ -148,37 +151,33 @@ def load_data(dirpath):
|
||||
datas[name] = _pickle.load(f)
|
||||
return datas
|
||||
|
||||
class MyTester(object):
|
||||
def __init__(self, batch_size, use_cuda=False, **kwagrs):
|
||||
self.batch_size = batch_size
|
||||
self.use_cuda = use_cuda
|
||||
class ParserEvaluator(Evaluator):
|
||||
def __init__(self):
|
||||
super(ParserEvaluator, self).__init__()
|
||||
|
||||
def test(self, model, dataset):
|
||||
self.model = model.cuda() if self.use_cuda else model
|
||||
self.model.eval()
|
||||
batchiter = Batch(dataset, self.batch_size, SequentialSampler(), self.use_cuda)
|
||||
eval_res = defaultdict(list)
|
||||
i = 0
|
||||
for batch_x, batch_y in batchiter:
|
||||
with torch.no_grad():
|
||||
pred_y = self.model(**batch_x)
|
||||
eval_one = self.model.evaluate(**pred_y, **batch_y)
|
||||
i += self.batch_size
|
||||
for eval_name, tensor in eval_one.items():
|
||||
eval_res[eval_name].append(tensor)
|
||||
tmp = {}
|
||||
for eval_name, tensorlist in eval_res.items():
|
||||
tmp[eval_name] = torch.cat(tensorlist, dim=0)
|
||||
def __call__(self, predict_list, truth_list):
|
||||
head_all, label_all, total_all = 0, 0, 0
|
||||
for pred, truth in zip(predict_list, truth_list):
|
||||
head, label, total = self.evaluate(**pred, **truth)
|
||||
head_all += head
|
||||
label_all += label
|
||||
total_all += total
|
||||
|
||||
self.res = self.model.metrics(**tmp)
|
||||
print(self.show_metrics())
|
||||
return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all}
|
||||
|
||||
def show_metrics(self):
|
||||
s = ""
|
||||
for name, val in self.res.items():
|
||||
s += '{}: {:.2f}\t'.format(name, val)
|
||||
return s
|
||||
def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_):
|
||||
"""
|
||||
Evaluate the performance of prediction.
|
||||
|
||||
:return : performance results.
|
||||
head_pred_correct: number of correctly predicted heads.
|
||||
label_pred_correct: number of correct predicted labels.
|
||||
total_tokens: number of predicted tokens
|
||||
"""
|
||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask
|
||||
_, label_preds = torch.max(label_pred, dim=2)
|
||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct
|
||||
return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item()
|
||||
|
||||
try:
|
||||
data_dict = load_data(processed_datadir)
|
||||
@ -196,6 +195,7 @@ except Exception as _:
|
||||
tag_v = Vocabulary(need_default=False)
|
||||
train_data = loader.load(os.path.join(datadir, train_data_name))
|
||||
dev_data = loader.load(os.path.join(datadir, dev_data_name))
|
||||
test_data = loader.load(os.path.join(datadir, test_data_name))
|
||||
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v)
|
||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data)
|
||||
|
||||
@ -207,8 +207,6 @@ dev_data.set_origin_len("word_seq")
|
||||
print(train_data[:3])
|
||||
print(len(train_data))
|
||||
print(len(dev_data))
|
||||
ep = train_args['epochs']
|
||||
train_args['epochs'] = math.ceil(50000.0 / len(train_data) * train_args['batch_size']) if ep <= 0 else ep
|
||||
model_args['word_vocab_size'] = len(word_v)
|
||||
model_args['pos_vocab_size'] = len(pos_v)
|
||||
model_args['num_label'] = len(tag_v)
|
||||
@ -220,7 +218,7 @@ def train():
|
||||
|
||||
def _define_optim(obj):
|
||||
obj._optimizer = torch.optim.Adam(obj._model.parameters(), **optim_args.data)
|
||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: .75 ** (ep / 5e4))
|
||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05))
|
||||
|
||||
def _update(obj):
|
||||
obj._scheduler.step()
|
||||
@ -228,8 +226,7 @@ def train():
|
||||
|
||||
trainer.define_optimizer = lambda: _define_optim(trainer)
|
||||
trainer.update = lambda: _update(trainer)
|
||||
trainer.get_loss = lambda predict, truth: trainer._loss_func(**predict, **truth)
|
||||
trainer._create_validator = lambda x: MyTester(**test_args.data)
|
||||
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator()))
|
||||
|
||||
# Model
|
||||
model = BiaffineParser(**model_args.data)
|
||||
@ -238,6 +235,7 @@ def train():
|
||||
word_v.unknown_label = "<OOV>"
|
||||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl'))
|
||||
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)
|
||||
|
||||
model.word_embedding.padding_idx = word_v.padding_idx
|
||||
model.word_embedding.weight.data[word_v.padding_idx].fill_(0)
|
||||
model.pos_embedding.padding_idx = pos_v.padding_idx
|
||||
@ -262,7 +260,7 @@ def train():
|
||||
|
||||
def test():
|
||||
# Tester
|
||||
tester = MyTester(**test_args.data)
|
||||
tester = Tester(**test_args.data, evaluator=ParserEvaluator())
|
||||
|
||||
# Model
|
||||
model = BiaffineParser(**model_args.data)
|
||||
@ -275,9 +273,10 @@ def test():
|
||||
raise
|
||||
|
||||
# Start testing
|
||||
print("Testing Dev data")
|
||||
tester.test(model, dev_data)
|
||||
print(tester.show_metrics())
|
||||
print("Testing finished!")
|
||||
print("Testing Test data")
|
||||
tester.test(model, test_data)
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user