mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00
- add validation loss into trainer.train
- restructure: move reproduction outside - add evaluate in tester
This commit is contained in:
parent
a73087e913
commit
7514be6f30
@ -1,5 +1,6 @@
|
||||
import _pickle
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from fastNLP.action.action import Action
|
||||
@ -16,8 +17,7 @@ class BaseTester(Action):
|
||||
"""
|
||||
super(BaseTester, self).__init__()
|
||||
self.validate_in_training = test_args["validate_in_training"]
|
||||
self.valid_x = None
|
||||
self.valid_y = None
|
||||
self.save_dev_data = None
|
||||
self.save_output = test_args["save_output"]
|
||||
self.output = None
|
||||
self.save_loss = test_args["save_loss"]
|
||||
@ -26,8 +26,14 @@ class BaseTester(Action):
|
||||
self.pickle_path = test_args["pickle_path"]
|
||||
self.iterator = None
|
||||
|
||||
self.model = None
|
||||
self.eval_history = []
|
||||
|
||||
def test(self, network):
|
||||
# print("--------------testing----------------")
|
||||
self.model = network
|
||||
|
||||
# turn on the testing mode; clean up the history
|
||||
self.mode(network, test=True)
|
||||
|
||||
dev_data = self.prepare_input(self.pickle_path)
|
||||
@ -35,7 +41,6 @@ class BaseTester(Action):
|
||||
self.iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
|
||||
|
||||
batch_output = list()
|
||||
eval_history = list()
|
||||
num_iter = len(dev_data) // self.batch_size
|
||||
|
||||
for step in range(num_iter):
|
||||
@ -47,11 +52,18 @@ class BaseTester(Action):
|
||||
if self.save_output:
|
||||
batch_output.append(prediction)
|
||||
if self.save_loss:
|
||||
eval_history.append(eval_results)
|
||||
self.eval_history.append(eval_results)
|
||||
|
||||
def prepare_input(self, data_path):
|
||||
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
|
||||
return data_dev
|
||||
"""
|
||||
Save the dev data once it is loaded. Can return directly next time.
|
||||
:param data_path: str, the path to the pickle data for dev
|
||||
:return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
|
||||
"""
|
||||
if self.save_dev_data is None:
|
||||
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
|
||||
self.save_dev_data = data_dev
|
||||
return self.save_dev_data
|
||||
|
||||
def batchify(self, data):
|
||||
"""
|
||||
@ -99,11 +111,12 @@ class BaseTester(Action):
|
||||
raise NotImplementedError
|
||||
|
||||
def mode(self, model, test=True):
|
||||
"""To do: combine this function with Trainer"""
|
||||
"""To do: combine this function with Trainer ?? """
|
||||
if test:
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
self.eval_history.clear()
|
||||
|
||||
|
||||
class POSTester(BaseTester):
|
||||
@ -115,6 +128,7 @@ class POSTester(BaseTester):
|
||||
super(POSTester, self).__init__(test_args)
|
||||
self.max_len = None
|
||||
self.mask = None
|
||||
self.batch_result = None
|
||||
|
||||
def data_forward(self, network, x):
|
||||
"""To Do: combine with Trainer
|
||||
@ -132,5 +146,9 @@ class POSTester(BaseTester):
|
||||
return y
|
||||
|
||||
def evaluate(self, predict, truth):
|
||||
"""To Do: """
|
||||
return 0
|
||||
truth = torch.Tensor(truth)
|
||||
loss, prediction = self.model.loss(predict, truth, self.mask, self.batch_size, self.max_len)
|
||||
return loss.data
|
||||
|
||||
def matrices(self):
|
||||
return np.mean(self.eval_history)
|
||||
|
@ -89,6 +89,7 @@ class BaseTrainer(Action):
|
||||
if data_dev is None:
|
||||
raise RuntimeError("No validation data provided.")
|
||||
validator.test(network)
|
||||
print("[epoch {}] dev loss={:.2f}".format(epoch, validator.matrices()))
|
||||
|
||||
# finish training
|
||||
|
||||
@ -386,6 +387,7 @@ class POSTrainer(BaseTrainer):
|
||||
else:
|
||||
self.define_loss()
|
||||
loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
|
||||
# print("loss={:.2f}".format(loss.data))
|
||||
return loss
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user