mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 12:17:35 +08:00
Combine make_batch for Trainer and Tester
- change parameter <seq_length-->mask> in loss function defined in seq model
- Trainer & Tester have Action as default parameter, shared static methods like make_batch
- add seq_len in make_batch of Inference
- add SeqLabelInfer, a subclass of Inference
- seq_labeling.py works
This commit is contained in:
parent
2c0079f3d5
commit
83f69b0e0f
@ -4,20 +4,16 @@
|
||||
"""
|
||||
from collections import Counter
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import _pickle
|
||||
|
||||
|
||||
class Action(object):
|
||||
"""
|
||||
Operations shared by Trainer, Tester, and Inference.
|
||||
This is designed to reduce duplicated code.
|
||||
- prepare_input: data preparation before a forward pass.
|
||||
- make_batch: produce a mini-batch of data. @staticmethod
|
||||
- pad: padding method used in sequence modeling. @staticmethod
|
||||
- mode: change network mode for either train or test. (for PyTorch) @staticmethod
|
||||
- data_forward: a forward pass of the network.
|
||||
The base Action shall define operations shared by as many task-specific Actions as possible.
|
||||
"""
|
||||
|
||||
@ -83,47 +79,6 @@ class Action(object):
|
||||
else:
|
||||
model.train()
|
||||
|
||||
def data_forward(self, network, x):
|
||||
"""
|
||||
Forward pass of the data.
|
||||
:param network: a model
|
||||
:param x: input feature matrix and label vector
|
||||
:return: output by the models
|
||||
|
||||
For PyTorch, just do "network(*x)"
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SeqLabelAction(Action):
|
||||
def __init__(self, action_args):
|
||||
"""
|
||||
Define task-specific member variables.
|
||||
:param action_args:
|
||||
"""
|
||||
super(SeqLabelAction, self).__init__()
|
||||
self.max_len = None
|
||||
self.mask = None
|
||||
self.best_accuracy = 0.0
|
||||
self.use_cuda = action_args["use_cuda"]
|
||||
self.seq_len = None
|
||||
self.batch_size = None
|
||||
|
||||
def data_forward(self, network, inputs):
|
||||
# unpack the returned value from make_batch
|
||||
if isinstance(inputs, tuple):
|
||||
x = inputs[0]
|
||||
self.seq_len = inputs[1]
|
||||
else:
|
||||
x = inputs
|
||||
x = torch.Tensor(x).long()
|
||||
if torch.cuda.is_available() and self.use_cuda:
|
||||
x = x.cuda()
|
||||
self.batch_size = x.size(0)
|
||||
self.max_len = x.size(1)
|
||||
y = network(x)
|
||||
return y
|
||||
|
||||
|
||||
def k_means_1d(x, k, max_iter=100):
|
||||
"""
|
||||
|
@ -1,7 +1,9 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from fastNLP.core.action import Batchifier, SequentialSampler
|
||||
from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
|
||||
from fastNLP.modules import utils
|
||||
|
||||
|
||||
class Inference(object):
|
||||
@ -32,13 +34,14 @@ class Inference(object):
|
||||
|
||||
# turn on the testing mode; clean up the history
|
||||
self.mode(network, test=True)
|
||||
self.batch_output.clear()
|
||||
|
||||
self.iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))
|
||||
iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))
|
||||
|
||||
num_iter = len(data) // self.batch_size
|
||||
|
||||
for step in range(num_iter):
|
||||
batch_x = self.make_batch(data)
|
||||
batch_x = self.make_batch(iterator, data)
|
||||
|
||||
prediction = self.data_forward(network, batch_x)
|
||||
|
||||
@ -54,26 +57,18 @@ class Inference(object):
|
||||
self.batch_output.clear()
|
||||
|
||||
def data_forward(self, network, x):
|
||||
"""
|
||||
This is only for sequence labeling with CRF decoder. TODO: more general ?
|
||||
:param network:
|
||||
:param x:
|
||||
:return:
|
||||
"""
|
||||
seq_len = [len(seq) for seq in x]
|
||||
x = torch.Tensor(x).long()
|
||||
y = network(x)
|
||||
prediction = network.prediction(y, seq_len)
|
||||
# To do: hide framework
|
||||
results = torch.Tensor(prediction).view(-1, )
|
||||
return list(results.data)
|
||||
raise NotImplementedError
|
||||
|
||||
def make_batch(self, data):
|
||||
indices = next(self.iterator)
|
||||
@staticmethod
|
||||
def make_batch(iterator, data, output_length=True):
|
||||
indices = next(iterator)
|
||||
batch_x = [data[idx] for idx in indices]
|
||||
if self.batch_size > 1:
|
||||
batch_x = self.pad(batch_x)
|
||||
return batch_x
|
||||
batch_x_pad = Inference.pad(batch_x)
|
||||
if output_length:
|
||||
seq_len = [len(x) for x in batch_x]
|
||||
return [batch_x_pad, seq_len]
|
||||
else:
|
||||
return batch_x_pad
|
||||
|
||||
@staticmethod
|
||||
def pad(batch, fill=0):
|
||||
@ -86,7 +81,7 @@ class Inference(object):
|
||||
max_length = max([len(x) for x in batch])
|
||||
for idx, sample in enumerate(batch):
|
||||
if len(sample) < max_length:
|
||||
batch[idx] = sample + [fill * (max_length - len(sample))]
|
||||
batch[idx] = sample + ([fill] * (max_length - len(sample)))
|
||||
return batch
|
||||
|
||||
def prepare_input(self, data):
|
||||
@ -109,10 +104,39 @@ class Inference(object):
|
||||
def prepare_output(self, batch_outputs):
|
||||
"""
|
||||
Transform list of batch outputs into strings.
|
||||
:param batch_outputs: list of list, of shape [num_batch, tag_seq_length]. Element type is Tensor.
|
||||
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length].
|
||||
:return:
|
||||
"""
|
||||
results = []
|
||||
for batch in batch_outputs:
|
||||
results.append([self.index2label[int(x.data)] for x in batch])
|
||||
for example in np.array(batch):
|
||||
results.append([self.index2label[int(x)] for x in example])
|
||||
return results
|
||||
|
||||
|
||||
class SeqLabelInfer(Inference):
|
||||
"""
|
||||
Inference on sequence labeling models.
|
||||
"""
|
||||
|
||||
def __init__(self, pickle_path):
|
||||
super(SeqLabelInfer, self).__init__(pickle_path)
|
||||
|
||||
def data_forward(self, network, inputs):
|
||||
"""
|
||||
This is only for sequence labeling with CRF decoder.
|
||||
:param network:
|
||||
:param inputs:
|
||||
:return: Tensor
|
||||
"""
|
||||
if not isinstance(inputs[1], list) and isinstance(inputs[0], list):
|
||||
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
|
||||
# unpack the returned value from make_batch
|
||||
x, seq_len = inputs[0], inputs[1]
|
||||
x = torch.Tensor(x).long()
|
||||
batch_size, max_len = x.size(0), x.size(1)
|
||||
mask = utils.seq_mask(seq_len, max_len)
|
||||
mask = mask.byte().view(batch_size, max_len)
|
||||
y = network(x)
|
||||
prediction = network.prediction(y, mask)
|
||||
return torch.Tensor(prediction)
|
||||
|
@ -6,17 +6,18 @@ import torch
|
||||
|
||||
from fastNLP.core.action import Action
|
||||
from fastNLP.core.action import RandomSampler, Batchifier
|
||||
from fastNLP.modules import utils
|
||||
|
||||
|
||||
class BaseTester(Action):
|
||||
"""docstring for Tester"""
|
||||
|
||||
def __init__(self, test_args, action):
|
||||
def __init__(self, test_args, action=None):
|
||||
"""
|
||||
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
|
||||
"""
|
||||
super(BaseTester, self).__init__()
|
||||
self.action = action
|
||||
self.action = action if action is not None else Action()
|
||||
self.validate_in_training = test_args["validate_in_training"]
|
||||
self.save_dev_data = None
|
||||
self.save_output = test_args["save_output"]
|
||||
@ -52,7 +53,7 @@ class BaseTester(Action):
|
||||
for step in range(num_iter):
|
||||
batch_x, batch_y = self.action.make_batch(iterator, dev_data)
|
||||
|
||||
prediction = self.action.data_forward(network, batch_x)
|
||||
prediction = self.data_forward(network, batch_x)
|
||||
|
||||
eval_results = self.evaluate(prediction, batch_y)
|
||||
|
||||
@ -72,6 +73,9 @@ class BaseTester(Action):
|
||||
self.save_dev_data = data_dev
|
||||
return self.save_dev_data
|
||||
|
||||
def data_forward(self, network, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def evaluate(self, predict, truth):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -92,7 +96,7 @@ class POSTester(BaseTester):
|
||||
Tester for sequence labeling.
|
||||
"""
|
||||
|
||||
def __init__(self, test_args, action):
|
||||
def __init__(self, test_args, action=None):
|
||||
"""
|
||||
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
|
||||
"""
|
||||
@ -101,17 +105,37 @@ class POSTester(BaseTester):
|
||||
self.mask = None
|
||||
self.batch_result = None
|
||||
|
||||
def data_forward(self, network, inputs):
|
||||
if not isinstance(inputs, tuple):
|
||||
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
|
||||
# unpack the returned value from make_batch
|
||||
x, seq_len = inputs[0], inputs[1]
|
||||
x = torch.Tensor(x).long()
|
||||
batch_size, max_len = x.size(0), x.size(1)
|
||||
mask = utils.seq_mask(seq_len, max_len)
|
||||
mask = mask.byte().view(batch_size, max_len)
|
||||
|
||||
if torch.cuda.is_available() and self.use_cuda:
|
||||
x = x.cuda()
|
||||
mask = mask.cuda()
|
||||
self.mask = mask
|
||||
|
||||
y = network(x)
|
||||
return y
|
||||
|
||||
def evaluate(self, predict, truth):
|
||||
truth = torch.Tensor(truth)
|
||||
if torch.cuda.is_available() and self.use_cuda:
|
||||
truth = truth.cuda()
|
||||
loss = self.model.loss(predict, truth, self.action.seq_len) / self.batch_size
|
||||
prediction = self.model.prediction(predict, self.action.seq_len)
|
||||
batch_size, max_len = predict.size(0), predict.size(1)
|
||||
loss = self.model.loss(predict, truth, self.mask) / batch_size
|
||||
|
||||
prediction = self.model.prediction(predict, self.mask)
|
||||
results = torch.Tensor(prediction).view(-1,)
|
||||
if torch.cuda.is_available() and self.use_cuda:
|
||||
results = results.cuda()
|
||||
accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
|
||||
return [loss.data, accuracy]
|
||||
# make sure "results" is in the same device as "truth"
|
||||
results = results.to(truth)
|
||||
accuracy = torch.sum(results == truth.view((-1,))) / results.shape[0]
|
||||
return [loss.data, accuracy.data]
|
||||
|
||||
def metrics(self):
|
||||
batch_loss = np.mean([x[0] for x in self.eval_history])
|
||||
|
@ -8,8 +8,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from fastNLP.core.action import Action
|
||||
from fastNLP.core.action import RandomSampler, Batchifier, BucketSampler
|
||||
from fastNLP.core.action import RandomSampler, Batchifier
|
||||
from fastNLP.core.tester import POSTester
|
||||
from fastNLP.modules import utils
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
|
||||
|
||||
@ -23,10 +24,10 @@ class BaseTrainer(Action):
|
||||
- get_loss
|
||||
"""
|
||||
|
||||
def __init__(self, train_args, action):
|
||||
def __init__(self, train_args, action=None):
|
||||
"""
|
||||
:param train_args: dict of (key, value), or dict-like object. key is str.
|
||||
:param action: an Action object that wrap most operations shared by Trainer, Tester, and Inference.
|
||||
:param action: (optional) an Action object that wrap most operations shared by Trainer, Tester, and Inference.
|
||||
|
||||
The base trainer requires the following keys:
|
||||
- epochs: int, the number of epochs in training
|
||||
@ -35,7 +36,7 @@ class BaseTrainer(Action):
|
||||
- pickle_path: str, the path to pickle files for pre-processing
|
||||
"""
|
||||
super(BaseTrainer, self).__init__()
|
||||
self.action = action
|
||||
self.action = action if action is not None else Action()
|
||||
self.n_epochs = train_args["epochs"]
|
||||
self.batch_size = train_args["batch_size"]
|
||||
self.pickle_path = train_args["pickle_path"]
|
||||
@ -94,7 +95,7 @@ class BaseTrainer(Action):
|
||||
for step in range(iterations):
|
||||
batch_x, batch_y = self.action.make_batch(iterator, data_train)
|
||||
|
||||
prediction = self.action.data_forward(network, batch_x)
|
||||
prediction = self.data_forward(network, batch_x)
|
||||
|
||||
loss = self.get_loss(prediction, batch_y)
|
||||
self.grad_backward(loss)
|
||||
@ -137,6 +138,9 @@ class BaseTrainer(Action):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def data_forward(self, network, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def grad_backward(self, loss):
|
||||
"""
|
||||
Compute gradient with link rules.
|
||||
@ -223,7 +227,8 @@ class POSTrainer(BaseTrainer):
|
||||
Trainer for Sequence Modeling
|
||||
|
||||
"""
|
||||
def __init__(self, train_args, action):
|
||||
|
||||
def __init__(self, train_args, action=None):
|
||||
super(POSTrainer, self).__init__(train_args, action)
|
||||
self.vocab_size = train_args["vocab_size"]
|
||||
self.num_classes = train_args["num_classes"]
|
||||
@ -241,6 +246,24 @@ class POSTrainer(BaseTrainer):
|
||||
def update(self):
|
||||
self.optimizer.step()
|
||||
|
||||
def data_forward(self, network, inputs):
|
||||
if not isinstance(inputs, tuple):
|
||||
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
|
||||
# unpack the returned value from make_batch
|
||||
x, seq_len = inputs[0], inputs[1]
|
||||
batch_size, max_len = x.size(0), x.size(1)
|
||||
mask = utils.seq_mask(seq_len, max_len)
|
||||
mask = mask.byte().view(batch_size, max_len)
|
||||
|
||||
x = torch.Tensor(x).long()
|
||||
if torch.cuda.is_available() and self.use_cuda:
|
||||
x = x.cuda()
|
||||
mask = mask.cuda()
|
||||
self.mask = mask
|
||||
|
||||
y = network(x)
|
||||
return y
|
||||
|
||||
def get_loss(self, predict, truth):
|
||||
"""
|
||||
Compute loss given prediction and ground truth.
|
||||
@ -251,13 +274,10 @@ class POSTrainer(BaseTrainer):
|
||||
truth = torch.Tensor(truth)
|
||||
if torch.cuda.is_available() and self.use_cuda:
|
||||
truth = truth.cuda()
|
||||
assert truth.shape == (self.batch_size, self.action.max_len)
|
||||
if self.loss_func is None:
|
||||
if hasattr(self.model, "loss"):
|
||||
self.loss_func = self.model.loss
|
||||
else:
|
||||
self.define_loss()
|
||||
loss = self.loss_func(predict, truth, self.action.seq_len)
|
||||
batch_size, max_len = predict.size(0), predict.size(1)
|
||||
assert truth.shape == (batch_size, max_len)
|
||||
|
||||
loss = self.model.loss(predict, truth, self.mask)
|
||||
return loss
|
||||
|
||||
def best_eval_result(self, validator):
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from fastNLP.models.base_model import BaseModel
|
||||
from fastNLP.modules import decoder, encoder, utils
|
||||
from fastNLP.modules import decoder, encoder
|
||||
|
||||
|
||||
class SeqLabeling(BaseModel):
|
||||
@ -34,46 +34,25 @@ class SeqLabeling(BaseModel):
|
||||
# [batch_size, max_len, num_classes]
|
||||
return x
|
||||
|
||||
def loss(self, x, y, seq_length):
|
||||
def loss(self, x, y, mask):
|
||||
"""
|
||||
Negative log likelihood loss.
|
||||
:param x: FloatTensor, [batch_size, max_len, tag_size]
|
||||
:param y: LongTensor, [batch_size, max_len]
|
||||
:param seq_length: list of int. [batch_size]
|
||||
:param x: Tensor, [batch_size, max_len, tag_size]
|
||||
:param y: Tensor, [batch_size, max_len]
|
||||
:param mask: ByteTensor, [batch_size, max_len]
|
||||
:return loss: a scalar Tensor
|
||||
|
||||
"""
|
||||
x = x.float()
|
||||
y = y.long()
|
||||
|
||||
batch_size = x.size(0)
|
||||
max_len = x.size(1)
|
||||
|
||||
mask = utils.seq_mask(seq_length, max_len)
|
||||
mask = mask.byte().view(batch_size, max_len)
|
||||
|
||||
# TODO: remove
|
||||
if torch.cuda.is_available():
|
||||
mask = mask.cuda()
|
||||
# mask = x.new(batch_size, max_len)
|
||||
|
||||
total_loss = self.Crf(x, y, mask)
|
||||
|
||||
return torch.mean(total_loss)
|
||||
|
||||
def prediction(self, x, seq_length):
|
||||
def prediction(self, x, mask):
|
||||
"""
|
||||
:param x: FloatTensor, [batch_size, max_len, tag_size]
|
||||
:param seq_length: int
|
||||
:return prediction: list of tuple of (decode path(list), best score)
|
||||
:param mask: ByteTensor, [batch_size, max_len]
|
||||
:return prediction: list of [decode path(list)]
|
||||
"""
|
||||
x = x.float()
|
||||
max_len = x.size(1)
|
||||
|
||||
mask = utils.seq_mask(seq_length, max_len)
|
||||
# hack: make sure mask has the same device as x
|
||||
mask = mask.to(x).byte()
|
||||
|
||||
tag_seq = self.Crf.viterbi_decode(x, mask)
|
||||
|
||||
return tag_seq
|
||||
|
@ -132,6 +132,7 @@ class ConditionalRandomField(nn.Module):
|
||||
Given a feats matrix, return best decode path and best score.
|
||||
:param feats:
|
||||
:param masks:
|
||||
:param get_score: bool, whether to output the decode score.
|
||||
:return:List[Tuple(List, float)],
|
||||
"""
|
||||
batch_size, max_len, tag_size = feats.size()
|
||||
|
@ -2,7 +2,6 @@ import sys
|
||||
|
||||
sys.path.append("..")
|
||||
|
||||
from fastNLP.core.action import SeqLabelAction
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.core.trainer import POSTrainer
|
||||
from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader
|
||||
@ -11,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
from fastNLP.core.tester import POSTester
|
||||
from fastNLP.models.sequence_modeling import SeqLabeling
|
||||
from fastNLP.core.inference import Inference
|
||||
from fastNLP.core.inference import SeqLabelInfer
|
||||
|
||||
data_name = "people.txt"
|
||||
data_path = "data_for_tests/people.txt"
|
||||
@ -51,10 +50,11 @@ def infer():
|
||||
"""
|
||||
|
||||
# Inference interface
|
||||
infer = Inference(pickle_path)
|
||||
infer = SeqLabelInfer(pickle_path)
|
||||
results = infer.predict(model, infer_data)
|
||||
|
||||
print(results)
|
||||
for res in results:
|
||||
print(res)
|
||||
print("Inference finished!")
|
||||
|
||||
|
||||
@ -72,10 +72,8 @@ def train_and_test():
|
||||
train_args["vocab_size"] = p.vocab_size
|
||||
train_args["num_classes"] = p.num_classes
|
||||
|
||||
action = SeqLabelAction(train_args)
|
||||
|
||||
# Trainer
|
||||
trainer = POSTrainer(train_args, action)
|
||||
trainer = POSTrainer(train_args)
|
||||
|
||||
# Model
|
||||
model = SeqLabeling(train_args)
|
||||
@ -103,7 +101,7 @@ def train_and_test():
|
||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
|
||||
|
||||
# Tester
|
||||
tester = POSTester(test_args, action)
|
||||
tester = POSTester(test_args)
|
||||
|
||||
# Start testing
|
||||
tester.test(model)
|
||||
@ -114,5 +112,5 @@ def train_and_test():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_and_test()
|
||||
|
||||
# train_and_test()
|
||||
infer()
|
||||
|
Loading…
Reference in New Issue
Block a user