mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-02 04:07:35 +08:00)

sequence labeling ready to Train!

This commit is contained in:
parent 83c032df5d
commit c98d5924b5
@@ -1,5 +1,4 @@
import _pickle
from collections import namedtuple

import numpy as np
import torch
@@ -22,18 +21,22 @@ class BaseTrainer(Action):
        - grad_backward
        - get_loss
    """
    TrainConfig = namedtuple("config", ["epochs", "validate", "batch_size", "pickle_path"])

    def __init__(self, train_args):
        """
        training parameters
        :param train_args: dict of (key, value)

        The base trainer requires the following keys:
            - epochs: int, the number of epochs in training
            - validate: bool, whether or not to validate on dev set
            - batch_size: int
            - pickle_path: str, the path to pickle files for pre-processing
        """
        super(BaseTrainer, self).__init__()
        self.train_args = train_args
        self.n_epochs = train_args.epochs
        # self.validate = train_args.validate
        self.batch_size = train_args.batch_size
        self.pickle_path = train_args.pickle_path
        self.n_epochs = train_args["epochs"]
        self.validate = train_args["validate"]
        self.batch_size = train_args["batch_size"]
        self.pickle_path = train_args["pickle_path"]
        self.model = None
        self.iterator = None
        self.loss_func = None
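The hunk above replaces attribute-style access on the TrainConfig namedtuple with key lookups on a plain dict. A minimal sketch of assembling and reading such a dict, using the four keys listed in the docstring (the concrete values are illustrative, not taken from the repo):

# Illustrative only: the dict of training arguments BaseTrainer.__init__ now expects.
train_args = {
    "epochs": 10,          # int, number of passes over the training set (example value)
    "validate": True,      # bool, whether to evaluate on the dev set
    "batch_size": 32,      # int, samples per step (example value)
    "pickle_path": "./",   # str, directory holding the pre-processed pickle files
}

# Reading the keys the same way the trainer does; a missing key raises KeyError,
# so every caller in this commit now supplies all of them.
n_epochs = train_args["epochs"]
batch_size = train_args["batch_size"]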
@@ -66,8 +69,9 @@ class BaseTrainer(Action):

        for epoch in range(self.n_epochs):
            self.mode(test=False)

            self.define_optimizer()
            self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True))

            for step in range(iterations):
                batch_x, batch_y = self.batchify(self.batch_size, data_train)
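The inner loop runs for a fixed number of iterations per epoch over batches drawn with drop_last=True, so only complete batches are used. A quick standalone check of that arithmetic (the dataset size here is an assumed example):

n_samples = 100     # assumed training-set size
batch_size = 32
iterations = n_samples // batch_size   # drop_last=True keeps only full batches
print(iterations)   # 3; the remaining 4 samples are skipped this epoch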
@@ -173,8 +177,6 @@ class BaseTrainer(Action):
        :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
        """
        if self.iterator is None:
            self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))
        indices = next(self.iterator)
        batch = [data[idx] for idx in indices]
        batch_x = [sample[0] for sample in batch]
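For reference, a self-contained sketch of what the batchify step does: draw a set of indices, gather the corresponding [features, labels] pairs, and split them into batch_x and batch_y with the shapes given in the docstring. The function and variable names here are hypothetical stand-ins, not fastNLP code:

import random

def batchify_sketch(data, batch_size):
    # data: list of samples, each a [features, labels] pair
    indices = random.sample(range(len(data)), batch_size)   # one random batch, no replacement
    batch = [data[idx] for idx in indices]
    batch_x = [sample[0] for sample in batch]                # [batch_size, max_len]
    batch_y = [sample[1] for sample in batch]                # [batch_size, num_labels]
    return batch_x, batch_y

toy_data = [[[1, 2, 3, 4], [0]], [[1, 3, 5, 2], [1]], [[2, 2, 2, 2], [0]]]
bx, by = batchify_sketch(toy_data, batch_size=2)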
@@ -304,6 +306,7 @@ class WordSegTrainer(BaseTrainer):
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.85)

    def get_loss(self, predict, truth):
        truth = torch.Tensor(truth)
        self._loss = torch.nn.CrossEntropyLoss(predict, truth)
        return self._loss
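In stock PyTorch, torch.nn.CrossEntropyLoss is constructed once and then called on a pair of (logits, target), where the target is a LongTensor of class indices; a minimal sketch of that standard usage:

import torch

loss_fn = torch.nn.CrossEntropyLoss()                 # build the criterion once
logits = torch.randn(4, 3, requires_grad=True)        # [batch_size, num_classes] raw scores
target = torch.tensor([0, 2, 1, 2])                   # class indices, dtype int64
loss = loss_fn(logits, target)                        # scalar tensor
loss.backward()                                       # gradients flow back to the logits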
@@ -316,13 +319,16 @@ class WordSegTrainer(BaseTrainer):
        self.optimizer.step()


class POSTrainer(BaseTrainer):
    TrainConfig = namedtuple("config", ["epochs", "batch_size", "pickle_path", "num_classes", "vocab_size"])


class POSTrainer(BaseTrainer):
    """
    Trainer for Sequence Modeling

    """
    def __init__(self, train_args):
        super(POSTrainer, self).__init__(train_args)
        self.vocab_size = train_args.vocab_size
        self.num_classes = train_args.num_classes
        self.vocab_size = train_args["vocab_size"]
        self.num_classes = train_args["num_classes"]
        self.max_len = None
        self.mask = None
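POSTrainer keeps max_len and mask around for variable-length sequences. A hedged sketch of the usual way a padding mask is built from sequence lengths (a common recipe, not necessarily the one fastNLP uses internally):

import torch

lengths = torch.tensor([4, 2, 3])                 # true lengths of three example sequences
max_len = int(lengths.max())                      # pad everything to the longest sequence
positions = torch.arange(max_len).unsqueeze(0)    # [1, max_len]
mask = positions < lengths.unsqueeze(1)           # [batch_size, max_len], true on real tokens
# The mask is later handed to the loss so padded positions do not contribute.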
@@ -357,6 +363,13 @@ class POSTrainer(BaseTrainer):
    def define_optimizer(self):
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)

    def grad_backward(self, loss):
        self.model.zero_grad()
        loss.backward()

    def update(self):
        self.optimizer.step()

    def get_loss(self, predict, truth):
        """
        Compute loss given prediction and ground truth.
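grad_backward and update split the usual PyTorch optimization step into two calls. A standalone sketch of the same clear-gradients / backward / step sequence, with a stand-in model and random data (only the SGD hyperparameters mirror the code above):

import torch

model = torch.nn.Linear(4, 2)                                           # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))

loss = torch.nn.functional.cross_entropy(model(x), y)
model.zero_grad()     # grad_backward(): clear old gradients, then backpropagate
loss.backward()
optimizer.step()      # update(): apply the SGD step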
@@ -364,16 +377,18 @@ class POSTrainer(BaseTrainer):
        :param truth: ground truth label vector, [batch_size, max_len]
        :return: a scalar
        """
        truth = torch.Tensor(truth)
        if self.loss_func is None:
            if hasattr(self.model, "loss"):
                self.loss_func = self.model.loss
            else:
                self.define_loss()
        return self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
        loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
        return loss


if __name__ == "__name__":
    train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./")
    train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"}
    trainer = BaseTrainer(train_args)
    data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10]
    trainer.batchify(batch_size=3, data=data_train)
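The toy dataset above is built with list replication, where the placement of * 10 matters. A quick illustration in plain Python, independent of the repo:

sample = [[1, 2, 3, 4], [0]]
a = [sample * 10]            # one element whose contents are the pair repeated ten times
b = [sample] * 10            # ten elements, each the same [features, labels] pair
print(len(a), len(a[0]))     # 1 20
print(len(b), len(b[0]))     # 10 2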
@@ -81,7 +81,7 @@ class SeqLabeling(BaseModel):
        x = x.float()
        y = y.long()
        mask = mask.byte()
        print(x.shape, y.shape, mask.shape)
        # print(x.shape, y.shape, mask.shape)

        if self.use_crf:
            total_loss = self.crf(x, y, mask)
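When use_crf is disabled, a common alternative is a token-level cross entropy in which the mask zeroes out padded positions; a hedged sketch of that pattern using a recent PyTorch API (illustrative, not the SeqLabeling implementation):

import torch

batch_size, max_len, num_classes = 2, 5, 4
x = torch.randn(batch_size, max_len, num_classes)            # per-token class scores
y = torch.randint(0, num_classes, (batch_size, max_len))     # gold tag indices
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 1.]])                  # 1.0 on real tokens, 0.0 on padding

loss_fn = torch.nn.CrossEntropyLoss(reduction="none")        # keep one loss value per token
per_token = loss_fn(x.view(-1, num_classes), y.view(-1))     # [batch_size * max_len]
total_loss = (per_token * mask.view(-1)).sum() / mask.sum()  # average over real tokens only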
@@ -1,3 +1,3 @@
numpy==1.14.2
numpy>=1.14.2
torch==0.4.0
torchvision==0.1.8
torchvision>=0.1.8
@@ -5,7 +5,7 @@ sys.path.append("..")
from fastNLP.action.trainer import POSTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess
from fastNLP.models.sequencce_modeling import SeqLabeling
from fastNLP.models.sequence_modeling import SeqLabeling

data_name = "people.txt"
data_path = "data_for_tests/people.txt"
@@ -22,13 +22,14 @@ if __name__ == "__main__":
    num_classes = p.num_classes

    # Trainer
    train_args = POSTrainer.TrainConfig(epochs=20, batch_size=1, num_classes=num_classes,
                                        vocab_size=vocab_size, pickle_path=pickle_path)
    train_args = {"epochs": 20, "batch_size": 1, "num_classes": num_classes,
                  "vocab_size": vocab_size, "pickle_path": pickle_path, "validate": False}
    trainer = POSTrainer(train_args)

    # Model
    model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)

    # Start training.
    # Start training
    trainer.train(model)

    print("Training finished!")