refactor word_seg model & its test
commit 1426fc3582 (parent 83fe6f9f21)
@@ -2,7 +2,7 @@ from collections import namedtuple
 import numpy as np

-from fastNLP.action import Action
+from fastNLP.action.action import Action


 class Tester(Action):

@@ -111,7 +111,7 @@ class BaseTrainer(Action):
         """
         raise NotImplementedError

-    def data_forward(self, network, *x):
+    def data_forward(self, network, x):
         """
         Forward pass of the data.
         :param network: a model

@@ -158,7 +158,7 @@ class ToyTrainer(BaseTrainer):
     def mode(self, test=False):
         self.model.mode(test)

-    def data_forward(self, network, *x):
+    def data_forward(self, network, x):
         return np.matmul(x, self.weight) + self.bias

     def grad_backward(self, loss):

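The two hunks above drop the star from data_forward, so the batch now arrives as a single argument instead of being unpacked into positional arguments. A minimal standalone sketch of the new calling convention (the MiniTrainer class, shapes, and values are illustrative only, not part of this commit):

import numpy as np

class MiniTrainer:
    # hypothetical stand-in mirroring ToyTrainer.data_forward after the change
    def __init__(self):
        self.weight = np.random.rand(4, 2)
        self.bias = np.random.rand(2)

    def data_forward(self, network, x):
        # x is the whole batch; previously it arrived as unpacked *x
        return np.matmul(x, self.weight) + self.bias

batch = np.random.rand(3, 4)                    # one batch of 3 samples, 4 features
out = MiniTrainer().data_forward(None, batch)   # -> array of shape (3, 2)
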
@@ -175,6 +175,91 @@ class ToyTrainer(BaseTrainer):
         self._optimizer.step()


+class WordSegTrainer(BaseTrainer):
+    """
+        reserve for changes
+    """
+
+    def __init__(self, train_args):
+        super(WordSegTrainer, self).__init__(train_args)
+        self.id2word = None
+        self.word2id = None
+        self.id2tag = None
+        self.tag2id = None
+
+        self.lstm_batch_size = 8
+        self.lstm_seq_len = 32  # Trainer batch_size == lstm_batch_size * lstm_seq_len
+        self.hidden_dim = 100
+        self.lstm_num_layers = 2
+        self.vocab_size = 100
+        self.word_emb_dim = 100
+
+        self.hidden = (self.to_var(torch.zeros(2, self.lstm_batch_size, self.word_emb_dim)),
+                       self.to_var(torch.zeros(2, self.lstm_batch_size, self.word_emb_dim)))
+
+        self.optimizer = None
+        self._loss = None
+
+        self.USE_GPU = False
+
+    def to_var(self, x):
+        if torch.cuda.is_available() and self.USE_GPU:
+            x = x.cuda()
+        return torch.autograd.Variable(x)
+
+    def prepare_input(self, data):
+        """
+            perform word indices lookup to convert strings into indices
+            :param data: list of string, each string contains word + space + [B, M, E, S]
+            :return
+        """
+        word_list = []
+        tag_list = []
+        for line in data:
+            if len(line) > 2:
+                tokens = line.split("#")
+                word_list.append(tokens[0])
+                tag_list.append(tokens[2][0])
+        self.id2word = list(set(word_list))
+        self.word2id = {word: idx for idx, word in enumerate(self.id2word)}
+        self.id2tag = list(set(tag_list))
+        self.tag2id = {tag: idx for idx, tag in enumerate(self.id2tag)}
+        words = np.array([self.word2id[w] for w in word_list]).reshape(-1, 1)
+        tags = np.array([self.tag2id[t] for t in tag_list]).reshape(-1, 1)
+        return words, tags
+
+    def mode(self, test=False):
+        if test:
+            self.model.eval()
+        else:
+            self.model.train()
+
+    def data_forward(self, network, x):
+        """
+        :param network: a PyTorch model
+        :param x: sequence of length [batch_size], word indices
+        :return:
+        """
+        x = x.reshape(self.lstm_batch_size, self.lstm_seq_len)
+        output, self.hidden = network(x, self.hidden)
+        return output
+
+    def define_optimizer(self):
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.85)
+
+    def get_loss(self, predict, truth):
+        self._loss = torch.nn.CrossEntropyLoss(predict, truth)
+        return self._loss
+
+    def grad_backward(self, network):
+        self.model.zero_grad()
+        self._loss.backward()
+        torch.nn.utils.clip_grad_norm(self.model.parameters(), 5, norm_type=2)
+
+    def update(self):
+        self.optimizer.step()
+
+
 if __name__ == "__name__":
     Config = namedtuple("config", ["epochs", "validate", "save_when_better", "log_per_step",
                                    "log_validation", "batch_size"])

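In the new WordSegTrainer, prepare_input builds the word and tag vocabularies on the fly and returns column vectors of indices, while data_forward reshapes each batch to (lstm_batch_size, lstm_seq_len) = (8, 32), so it assumes Trainer batches of 8 * 32 = 256 word indices. A standalone sketch of the same lookup logic (the sample lines are invented; the exact "#"-separated layout of the cws_train data is not shown in this commit):

import numpy as np

lines = ["迈#x#B", "向#x#E", "新#x#S"]            # hypothetical "word#?#tag" lines
word_list = [ln.split("#")[0] for ln in lines if len(ln) > 2]
tag_list = [ln.split("#")[2][0] for ln in lines if len(ln) > 2]

id2word = list(set(word_list))
word2id = {word: idx for idx, word in enumerate(id2word)}
id2tag = list(set(tag_list))
tag2id = {tag: idx for idx, tag in enumerate(id2tag)}

words = np.array([word2id[w] for w in word_list]).reshape(-1, 1)  # shape (3, 1)
tags = np.array([tag2id[t] for t in tag_list]).reshape(-1, 1)     # shape (3, 1)
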
@@ -6,11 +6,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from model.base_model import BaseModel
 from torch.autograd import Variable

+from fastNLP.models.base_model import BaseModel
+
 USE_GPU = True

+"""
+    To be deprecated.
+"""
+

 class CharLM(BaseModel):
     """

@@ -1,95 +1,6 @@
-import numpy as np
-import torch
 import torch.nn as nn
-import torch.optim as optim
-from torch.autograd import Variable
-
-from fastNLP.models.base_model import BaseModel, BaseController
-
-USE_GPU = True
-
-
-def to_var(x):
-    if torch.cuda.is_available() and USE_GPU:
-        x = x.cuda()
-    return Variable(x)
-
-
-class WordSegModel(BaseController):
-    """
-        Model controller for WordSeg
-    """
-
-    def __init__(self):
-        super(WordSegModel, self).__init__()
-        self.id2word = None
-        self.word2id = None
-        self.id2tag = None
-        self.tag2id = None
-
-        self.lstm_batch_size = 8
-        self.lstm_seq_len = 32  # Trainer batch_size == lstm_batch_size * lstm_seq_len
-        self.hidden_dim = 100
-        self.lstm_num_layers = 2
-        self.vocab_size = 100
-        self.word_emb_dim = 100
-
-        self.model = WordSeg(self.hidden_dim, self.lstm_num_layers, self.vocab_size, self.word_emb_dim)
-        self.hidden = (to_var(torch.zeros(2, self.lstm_batch_size, self.word_emb_dim)),
-                       to_var(torch.zeros(2, self.lstm_batch_size, self.word_emb_dim)))
-
-        self.optimizer = None
-        self._loss = None
-
-    def prepare_input(self, data):
-        """
-            perform word indices lookup to convert strings into indices
-            :param data: list of string, each string contains word + space + [B, M, E, S]
-            :return
-        """
-        word_list = []
-        tag_list = []
-        for line in data:
-            if len(line) > 2:
-                tokens = line.split("#")
-                word_list.append(tokens[0])
-                tag_list.append(tokens[2][0])
-        self.id2word = list(set(word_list))
-        self.word2id = {word: idx for idx, word in enumerate(self.id2word)}
-        self.id2tag = list(set(tag_list))
-        self.tag2id = {tag: idx for idx, tag in enumerate(self.id2tag)}
-        words = np.array([self.word2id[w] for w in word_list]).reshape(-1, 1)
-        tags = np.array([self.tag2id[t] for t in tag_list]).reshape(-1, 1)
-        return words, tags
-
-    def mode(self, test=False):
-        if test:
-            self.model.eval()
-        else:
-            self.model.train()
-
-    def data_forward(self, x):
-        """
-        :param x: sequence of length [batch_size], word indices
-        :return:
-        """
-        x = x.reshape(self.lstm_batch_size, self.lstm_seq_len)
-        output, self.hidden = self.model(x, self.hidden)
-        return output
-
-    def define_optimizer(self):
-        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.85)
-
-    def get_loss(self, pred, truth):
-        self._loss = nn.CrossEntropyLoss(pred, truth)
-        return self._loss
-
-    def grad_backward(self):
-        self.model.zero_grad()
-        self._loss.backward()
-        torch.nn.utils.clip_grad_norm(self.model.parameters(), 5, norm_type=2)
-        self.optimizer.step()
+from fastNLP.models.base_model import BaseModel


 class WordSeg(BaseModel):

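Both the removed WordSegModel controller and the new WordSegTrainer pass the predictions and targets straight to nn.CrossEntropyLoss(...). In stock PyTorch that call constructs the criterion module rather than computing a loss value; the usual idiom, sketched below independently of this commit, is to instantiate the criterion once and then call it on logits and targets:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()         # build the criterion once

logits = torch.randn(4, 5)                # hypothetical batch: 4 samples, 5 classes
targets = torch.tensor([0, 3, 1, 4])      # hypothetical gold class indices
loss = criterion(logits, targets)         # scalar tensor; loss.backward() works on this
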
@@ -1,23 +1,20 @@
-from loader.base_loader import BaseLoader
-from model.word_seg_model import WordSegModel
-from fastNLP.action import Tester
-from fastNLP.action.trainer import Trainer
+from fastNLP.action.tester import Tester
+from fastNLP.action.trainer import WordSegTrainer
+from fastNLP.loader.base_loader import BaseLoader
+from fastNLP.models.word_seg_model import WordSeg


-def test_charlm():
-    train_config = Trainer.TrainConfig(epochs=5, validate=False, save_when_better=False,
+def test_wordseg():
+    train_config = WordSegTrainer.TrainConfig(epochs=5, validate=False, save_when_better=False,
                                        log_per_step=10, log_validation=False, batch_size=254)
-    trainer = Trainer(train_config)
+    trainer = WordSegTrainer(train_config)

-    model = WordSegModel()
+    model = WordSeg(100, 2, 1000)

     train_data = BaseLoader("load_train", "./data_for_tests/cws_train").load_lines()

     trainer.train(model, train_data)

-    trainer.save_model(model)
-
     test_config = Tester.TestConfig(save_output=False, validate_in_training=False,
                                     save_dev_input=False, save_loss=False, batch_size=254)
     tester = Tester(test_config)
@@ -28,4 +25,4 @@ def test_charlm():


 if __name__ == "__main__":
-    test_charlm()
+    test_wordseg()