Sequence labeling ready to train!

FengZiYjun 2018-07-10 20:46:35 +08:00
parent 83c032df5d
commit c98d5924b5
4 changed files with 40 additions and 24 deletions

View File

@ -1,5 +1,4 @@
import _pickle
from collections import namedtuple
import numpy as np
import torch
@ -22,18 +21,22 @@ class BaseTrainer(Action):
- grad_backward
- get_loss
"""
TrainConfig = namedtuple("config", ["epochs", "validate", "batch_size", "pickle_path"])
def __init__(self, train_args):
"""
training parameters
:param train_args: dict of (key, value)
The base trainer requires the following keys:
- epochs: int, the number of epochs in training
- validate: bool, whether or not to validate on dev set
- batch_size: int
- pickle_path: str, the path to pickle files for pre-processing
"""
super(BaseTrainer, self).__init__()
self.train_args = train_args
self.n_epochs = train_args.epochs
# self.validate = train_args.validate
self.batch_size = train_args.batch_size
self.pickle_path = train_args.pickle_path
self.n_epochs = train_args["epochs"]
self.validate = train_args["validate"]
self.batch_size = train_args["batch_size"]
self.pickle_path = train_args["pickle_path"]
self.model = None
self.iterator = None
self.loss_func = None
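With this change, train_args is a plain dict rather than the old TrainConfig namedtuple, so a configuration can be written inline. A minimal sketch of the expected keys (the values here are illustrative only):

train_args = {
    "epochs": 5,           # number of passes over the training set
    "validate": True,      # evaluate on the dev set during training
    "batch_size": 32,
    "pickle_path": "./",   # directory holding the pre-processed pickle files
}
trainer = BaseTrainer(train_args)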
@ -66,8 +69,9 @@ class BaseTrainer(Action):
for epoch in range(self.n_epochs):
self.mode(test=False)
self.define_optimizer()
self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True))
for step in range(iterations):
batch_x, batch_y = self.batchify(self.batch_size, data_train)
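The way this iterator is consumed here and in batchify below suggests that RandomSampler yields sample indices in shuffled order and Batchifier groups them into fixed-size index batches, dropping the last incomplete one. A minimal stand-in with that behaviour, assuming this interface (not the actual fastNLP implementation):

import random

class RandomSampler:
    """Yield the indices of a dataset in random order."""
    def __init__(self, data):
        self.indices = list(range(len(data)))

    def __iter__(self):
        random.shuffle(self.indices)
        return iter(self.indices)

class Batchifier:
    """Group sampled indices into batches of batch_size."""
    def __init__(self, sampler, batch_size, drop_last=True):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch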
@ -173,8 +177,6 @@ class BaseTrainer(Action):
:return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
"""
if self.iterator is None:
self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))
indices = next(self.iterator)
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
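As a concrete illustration of the format expected here, each sample in data is a [features, labels] pair, and one call pulls batch_size of them apart into parallel lists. Toy values, assuming the same data layout as the smoke test at the bottom of this file:

data = [[[1, 2, 3, 4], [0]],
        [[1, 3, 5, 2], [1]],
        [[2, 2, 2, 2], [0]]]
batch = [data[idx] for idx in [0, 2]]        # indices produced by the iterator
batch_x = [sample[0] for sample in batch]    # [[1, 2, 3, 4], [2, 2, 2, 2]]
batch_y = [sample[1] for sample in batch]    # [[0], [0]]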
@ -304,6 +306,7 @@ class WordSegTrainer(BaseTrainer):
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.85)
def get_loss(self, predict, truth):
truth = torch.LongTensor(truth)  # CrossEntropyLoss expects integer class targets
self._loss = torch.nn.CrossEntropyLoss()(predict, truth)  # instantiate the module, then call it
return self._loss
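Note that torch.nn.CrossEntropyLoss is a module: it is instantiated first and then called on (logits, targets), with integer class indices as targets. A small self-contained check of that call pattern (toy shapes):

import torch

loss_fn = torch.nn.CrossEntropyLoss()
logits = torch.randn(4, 10)               # [batch_size, num_classes]
targets = torch.LongTensor([1, 0, 3, 9])  # one class index per sample
loss = loss_fn(logits, targets)           # scalar tensor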
@ -316,13 +319,16 @@ class WordSegTrainer(BaseTrainer):
self.optimizer.step()
class POSTrainer(BaseTrainer):
TrainConfig = namedtuple("config", ["epochs", "batch_size", "pickle_path", "num_classes", "vocab_size"])
class POSTrainer(BaseTrainer):
"""
Trainer for Sequence Modeling
"""
def __init__(self, train_args):
super(POSTrainer, self).__init__(train_args)
self.vocab_size = train_args.vocab_size
self.num_classes = train_args.num_classes
self.vocab_size = train_args["vocab_size"]
self.num_classes = train_args["num_classes"]
self.max_len = None
self.mask = None
@ -357,6 +363,13 @@ class POSTrainer(BaseTrainer):
def define_optimizer(self):
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
def grad_backward(self, loss):
self.model.zero_grad()
loss.backward()
def update(self):
self.optimizer.step()
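Taken together, define_optimizer, grad_backward and update cover one SGD step. A self-contained illustration of the same sequence on a toy model (standard PyTorch, not the fastNLP classes, just the pattern these hooks wrap):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # define_optimizer

x = torch.randn(3, 4)
y = torch.LongTensor([0, 1, 1])

logits = model(x)
loss = torch.nn.CrossEntropyLoss()(logits, y)  # get_loss
model.zero_grad()                              # grad_backward: clear old gradients
loss.backward()                                #                backprop
optimizer.step()                               # update: apply the SGD step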
def get_loss(self, predict, truth):
"""
Compute loss given prediction and ground truth.
@ -364,16 +377,18 @@ class POSTrainer(BaseTrainer):
:param truth: ground truth label vector, [batch_size, max_len]
:return: a scalar
"""
truth = torch.Tensor(truth)
if self.loss_func is None:
if hasattr(self.model, "loss"):
self.loss_func = self.model.loss
else:
self.define_loss()
return self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
return loss
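The loss function invoked here is expected to score per-token predictions of shape [batch_size, max_len, num_classes] against targets of shape [batch_size, max_len], use mask to ignore padding, and return both the loss and a decoded prediction. A hedged sketch of a masked cross-entropy with that interface; the real self.loss_func may differ (for instance a CRF log-likelihood):

import torch

def masked_seq_loss(predict, truth, mask, batch_size, max_len):
    # predict: FloatTensor [batch_size, max_len, num_classes] of per-token scores
    # truth:   LongTensor  [batch_size, max_len] of gold label ids
    # mask:    ByteTensor  [batch_size, max_len], 1 for real tokens, 0 for padding
    num_classes = predict.size(-1)
    flat_scores = predict.view(batch_size * max_len, num_classes)
    flat_truth = truth.view(batch_size * max_len)
    token_loss = torch.nn.functional.cross_entropy(
        flat_scores, flat_truth, reduce=False)        # reduction='none' in newer PyTorch
    token_loss = token_loss.view(batch_size, max_len) * mask.float()
    loss = token_loss.sum() / mask.float().sum()      # average over real tokens only
    prediction = predict.max(dim=2)[1]                # greedy per-token decode
    return loss, prediction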
if __name__ == "__main__":
train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./")
train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"}
trainer = BaseTrainer(train_args)
data_train = [[[1, 2, 3, 4], [0]]] * 10 + [[[1, 3, 5, 2], [1]]] * 10  # 20 toy (features, label) samples
trainer.batchify(batch_size=3, data=data_train)

View File

@ -81,7 +81,7 @@ class SeqLabeling(BaseModel):
x = x.float()
y = y.long()
mask = mask.byte()
print(x.shape, y.shape, mask.shape)
# print(x.shape, y.shape, mask.shape)
if self.use_crf:
total_loss = self.crf(x, y, mask)
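For reference, the dtype conversions above line up with what a CRF-style sequence loss typically consumes. A toy example of tensors in those dtypes; the shapes are assumptions based on the rest of this commit, not taken from SeqLabeling itself:

import torch

batch_size, max_len, num_classes = 2, 5, 4
x = torch.randn(batch_size, max_len, num_classes).float()           # per-token label scores
y = torch.LongTensor(batch_size, max_len).random_(0, num_classes)   # gold label ids
mask = torch.ones(batch_size, max_len).byte()                       # 1 = real token, 0 = padding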

View File

@ -1,3 +1,3 @@
numpy==1.14.2
numpy>=1.14.2
torch==0.4.0
torchvision==0.1.8
torchvision>=0.1.8

View File

@ -5,7 +5,7 @@ sys.path.append("..")
from fastNLP.action.trainer import POSTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess
from fastNLP.models.sequencce_modeling import SeqLabeling
from fastNLP.models.sequence_modeling import SeqLabeling
data_name = "people.txt"
data_path = "data_for_tests/people.txt"
@ -22,13 +22,14 @@ if __name__ == "__main__":
num_classes = p.num_classes
# Trainer
train_args = POSTrainer.TrainConfig(epochs=20, batch_size=1, num_classes=num_classes,
vocab_size=vocab_size, pickle_path=pickle_path)
train_args = {"epochs": 20, "batch_size": 1, "num_classes": num_classes,
"vocab_size": vocab_size, "pickle_path": pickle_path, "validate": False}
trainer = POSTrainer(train_args)
# Model
model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)
# Start training.
# Start training
trainer.train(model)
print("Training finished!")