- report validation loss in trainer.train

- restructure: move reproduction outside
- implement evaluate in tester
This commit is contained in:
FengZiYjun 2018-07-11 21:51:35 +08:00
parent a73087e913
commit 7514be6f30
30 changed files with 29 additions and 9 deletions

View File

@ -1,5 +1,6 @@
import _pickle
import numpy as np
import torch
from fastNLP.action.action import Action
@ -16,8 +17,7 @@ class BaseTester(Action):
"""
super(BaseTester, self).__init__()
self.validate_in_training = test_args["validate_in_training"]
self.valid_x = None
self.valid_y = None
self.save_dev_data = None
self.save_output = test_args["save_output"]
self.output = None
self.save_loss = test_args["save_loss"]
@ -26,8 +26,14 @@ class BaseTester(Action):
self.pickle_path = test_args["pickle_path"]
self.iterator = None
self.model = None
self.eval_history = []
def test(self, network):
# print("--------------testing----------------")
self.model = network
# turn on the testing mode; clean up the history
self.mode(network, test=True)
dev_data = self.prepare_input(self.pickle_path)
@ -35,7 +41,6 @@ class BaseTester(Action):
self.iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
batch_output = list()
eval_history = list()
num_iter = len(dev_data) // self.batch_size
for step in range(num_iter):
@ -47,11 +52,18 @@ class BaseTester(Action):
if self.save_output:
batch_output.append(prediction)
if self.save_loss:
eval_history.append(eval_results)
self.eval_history.append(eval_results)
def prepare_input(self, data_path):
    """
    Save the dev data once it is loaded. Can return directly next time.

    The pickle file is read only on the first call; later calls return the
    cached object in ``self.save_dev_data`` regardless of ``data_path``.

    :param data_path: str, the path to the pickle data for dev
    :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
    """
    if self.save_dev_data is None:
        # NOTE(review): this reads "data_train.pkl" although it serves the dev set - confirm the file name.
        # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
        # The context manager closes the file handle even if unpickling fails.
        with open(data_path + "/data_train.pkl", "rb") as pickle_file:
            self.save_dev_data = _pickle.load(pickle_file)
    return self.save_dev_data
def batchify(self, data):
"""
@ -99,11 +111,12 @@ class BaseTester(Action):
raise NotImplementedError
def mode(self, model, test=True):
    """
    Put the model into evaluation or training mode and reset the history.

    To do: combine this function with Trainer ??

    :param model: the network; ``model.eval()`` or ``model.train()`` is called on it.
    :param test: bool, True selects ``model.eval()``, False selects ``model.train()``.
    """
    if test:
        model.eval()
    else:
        model.train()
    # Clean up the history on every switch, so each run starts from an empty list.
    self.eval_history.clear()
class POSTester(BaseTester):
@ -115,6 +128,7 @@ class POSTester(BaseTester):
super(POSTester, self).__init__(test_args)
self.max_len = None
self.mask = None
self.batch_result = None
def data_forward(self, network, x):
"""To Do: combine with Trainer
@ -132,5 +146,9 @@ class POSTester(BaseTester):
return y
def evaluate(self, predict, truth):
    """
    Compute the loss of one batch with the model's own loss function.

    :param predict: the network output for the batch, passed unchanged to ``self.model.loss``.
    :param truth: the gold labels for the batch; converted with ``torch.Tensor`` before use.
    :return: the ``.data`` of the loss returned by ``self.model.loss``.
    """
    truth = torch.Tensor(truth)
    # self.model.loss returns a (loss, prediction) pair; only the loss is used here.
    loss, _ = self.model.loss(predict, truth, self.mask, self.batch_size, self.max_len)
    return loss.data
def matrices(self):
    """
    Return the mean of the evaluation results stored in ``self.eval_history``.

    NOTE(review): the name is presumably a typo for "metrics"; it is kept
    because callers use it.

    :return: ``np.mean`` of the history, or NaN when the history is empty.
    """
    if not self.eval_history:
        # np.mean([]) would also give NaN but emits a RuntimeWarning; return it quietly.
        return np.nan
    return np.mean(self.eval_history)

View File

@ -89,6 +89,7 @@ class BaseTrainer(Action):
if data_dev is None:
raise RuntimeError("No validation data provided.")
validator.test(network)
print("[epoch {}] dev loss={:.2f}".format(epoch, validator.matrices()))
# finish training
@ -386,6 +387,7 @@ class POSTrainer(BaseTrainer):
else:
self.define_loss()
loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
# print("loss={:.2f}".format(loss.data))
return loss