mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 04:37:37 +08:00
Merge pull request #74 from 2017alan/master
Add weight initialization for models.
This commit is contained in:
commit
47772a88be
@ -37,5 +37,7 @@ class Loss(object):
|
||||
"""
|
||||
if loss_name == "cross_entropy":
|
||||
return torch.nn.CrossEntropyLoss()
|
||||
elif loss_name == 'nll':
|
||||
return torch.nn.NLLLoss()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
@ -297,7 +297,7 @@ class ClassPreprocess(BasePreprocess):
|
||||
|
||||
# build vocabulary from scratch if nothing exists
|
||||
word2index = DEFAULT_WORD_TO_INDEX.copy()
|
||||
label2index = DEFAULT_WORD_TO_INDEX.copy()
|
||||
label2index = {} # DEFAULT_WORD_TO_INDEX.copy()
|
||||
|
||||
# collect every word and label
|
||||
for sent, label in data:
|
||||
|
@ -1,8 +1,10 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
class SelfAttention(nn.Module):
|
||||
"""
|
||||
Self Attention Module.
|
||||
@ -13,13 +15,18 @@ class SelfAttention(nn.Module):
|
||||
num_vec: int, the number of encoded vectors
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, dim=10, num_vec=10):
|
||||
def __init__(self, input_size, dim=10, num_vec=10 ,drop = 0.5 ,initial_method =None):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
|
||||
self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
|
||||
# self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
|
||||
# self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
|
||||
self.attention_hops = num_vec
|
||||
|
||||
self.ws1 = nn.Linear(input_size, dim, bias=False)
|
||||
self.ws2 = nn.Linear(dim, num_vec, bias=False)
|
||||
self.drop = nn.Dropout(drop)
|
||||
self.softmax = nn.Softmax(dim=2)
|
||||
self.tanh = nn.Tanh()
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def penalization(self, A):
|
||||
"""
|
||||
compute the penalization term for attention module
|
||||
@ -32,11 +39,33 @@ class SelfAttention(nn.Module):
|
||||
M = M.view(M.size(0), -1)
|
||||
return torch.sum(M ** 2, dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2)))
|
||||
A = self.softmax(torch.matmul(self.W_s2, inter))
|
||||
out = torch.matmul(A, x)
|
||||
out = out.view(out.size(0), -1)
|
||||
penalty = self.penalization(A)
|
||||
return out, penalty
|
||||
def forward(self, outp ,inp):
|
||||
# the following code can not be use because some word are padding ,these is not such module!
|
||||
|
||||
# inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) # []
|
||||
# A = self.softmax(torch.matmul(self.W_s2, inter))
|
||||
# out = torch.matmul(A, x)
|
||||
# out = out.view(out.size(0), -1)
|
||||
# penalty = self.penalization(A)
|
||||
# return out, penalty
|
||||
outp = outp.contiguous()
|
||||
size = outp.size() # [bsz, len, nhid]
|
||||
|
||||
compressed_embeddings = outp.view(-1, size[2]) # [bsz*len, nhid*2]
|
||||
transformed_inp = torch.transpose(inp, 0, 1).contiguous() # [bsz, len]
|
||||
transformed_inp = transformed_inp.view(size[0], 1, size[1]) # [bsz, 1, len]
|
||||
concatenated_inp = [transformed_inp for i in range(self.attention_hops)]
|
||||
concatenated_inp = torch.cat(concatenated_inp, 1) # [bsz, hop, len]
|
||||
|
||||
hbar = self.tanh(self.ws1(self.drop(compressed_embeddings))) # [bsz*len, attention-unit]
|
||||
attention = self.ws2(hbar).view(size[0], size[1], -1) # [bsz, len, hop]
|
||||
attention = torch.transpose(attention, 1, 2).contiguous() # [bsz, hop, len]
|
||||
penalized_alphas = attention + (
|
||||
-10000 * (concatenated_inp == 0).float())
|
||||
# [bsz, hop, len] + [bsz, hop, len]
|
||||
attention = self.softmax(penalized_alphas.view(-1, size[1])) # [bsz*hop, len]
|
||||
attention = attention.view(size[0], self.attention_hops, size[1]) # [bsz, hop, len]
|
||||
return torch.bmm(attention, outp), attention # output --> [baz ,hop ,nhid]
|
||||
|
||||
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
|
||||
def log_sum_exp(x, dim=-1):
|
||||
max_value, _ = x.max(dim=dim, keepdim=True)
|
||||
@ -19,7 +20,7 @@ def seq_len_to_byte_mask(seq_lens):
|
||||
|
||||
|
||||
class ConditionalRandomField(nn.Module):
|
||||
def __init__(self, tag_size, include_start_end_trans=True):
|
||||
def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None):
|
||||
"""
|
||||
:param tag_size: int, num of tags
|
||||
:param include_start_end_trans: bool, whether to include start/end tag
|
||||
@ -35,8 +36,8 @@ class ConditionalRandomField(nn.Module):
|
||||
self.start_scores = nn.Parameter(torch.randn(tag_size))
|
||||
self.end_scores = nn.Parameter(torch.randn(tag_size))
|
||||
|
||||
self.reset_parameter()
|
||||
|
||||
# self.reset_parameter()
|
||||
initial_parameter(self, initial_method)
|
||||
def reset_parameter(self):
|
||||
nn.init.xavier_normal_(self.transition_m)
|
||||
if self.include_start_end_trans:
|
||||
|
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, size_layer, num_class=2, activation='relu'):
|
||||
def __init__(self, size_layer, num_class=2, activation='relu' , initial_method = None):
|
||||
"""Multilayer Perceptrons as a decoder
|
||||
|
||||
Args:
|
||||
@ -36,7 +36,7 @@ class MLP(nn.Module):
|
||||
self.hidden_active = activation
|
||||
else:
|
||||
raise ValueError("should set activation correctly: {}".format(activation))
|
||||
|
||||
initial_parameter(self, initial_method )
|
||||
def forward(self, x):
|
||||
for layer in self.hiddens:
|
||||
x = self.hidden_active(layer(x))
|
||||
|
@ -1,11 +1,12 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
# from torch.nn.init import xavier_uniform
|
||||
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
class ConvCharEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5)):
|
||||
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5),initial_method = None):
|
||||
"""
|
||||
Character Level Word Embedding
|
||||
:param char_emb_size: the size of character level embedding. Default: 50
|
||||
@ -20,6 +21,8 @@ class ConvCharEmbedding(nn.Module):
|
||||
nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4))
|
||||
for i in range(len(kernels))])
|
||||
|
||||
initial_parameter(self,initial_method)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
:param x: [batch_size * sent_length, word_length, char_emb_size]
|
||||
@ -53,7 +56,7 @@ class LSTMCharEmbedding(nn.Module):
|
||||
:param hidden_size: int, the number of hidden units. Default: equal to char_emb_size.
|
||||
"""
|
||||
|
||||
def __init__(self, char_emb_size=50, hidden_size=None):
|
||||
def __init__(self, char_emb_size=50, hidden_size=None , initial_method= None):
|
||||
super(LSTMCharEmbedding, self).__init__()
|
||||
self.hidden_size = char_emb_size if hidden_size is None else hidden_size
|
||||
|
||||
@ -62,7 +65,7 @@ class LSTMCharEmbedding(nn.Module):
|
||||
num_layers=1,
|
||||
bias=True,
|
||||
batch_first=True)
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def forward(self, x):
|
||||
"""
|
||||
:param x:[ n_batch*n_word, word_length, char_emb_size]
|
||||
|
@ -6,6 +6,7 @@ import torch.nn as nn
|
||||
from torch.nn.init import xavier_uniform_
|
||||
# import torch.nn.functional as F
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
|
||||
class Conv(nn.Module):
|
||||
"""
|
||||
@ -15,7 +16,7 @@ class Conv(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size,
|
||||
stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=True, activation='relu'):
|
||||
groups=1, bias=True, activation='relu',initial_method = None ):
|
||||
super(Conv, self).__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
@ -26,7 +27,7 @@ class Conv(nn.Module):
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias)
|
||||
xavier_uniform_(self.conv.weight)
|
||||
# xavier_uniform_(self.conv.weight)
|
||||
|
||||
activations = {
|
||||
'relu': nn.ReLU(),
|
||||
@ -37,6 +38,7 @@ class Conv(nn.Module):
|
||||
raise Exception(
|
||||
'Should choose activation function from: ' +
|
||||
', '.join([x for x in activations]))
|
||||
initial_parameter(self, initial_method)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L]
|
||||
|
@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.init import xavier_uniform_
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
|
||||
class ConvMaxpool(nn.Module):
|
||||
"""
|
||||
@ -14,7 +14,7 @@ class ConvMaxpool(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_sizes,
|
||||
stride=1, padding=0, dilation=1,
|
||||
groups=1, bias=True, activation='relu'):
|
||||
groups=1, bias=True, activation='relu',initial_method = None ):
|
||||
super(ConvMaxpool, self).__init__()
|
||||
|
||||
# convolution
|
||||
@ -47,6 +47,8 @@ class ConvMaxpool(nn.Module):
|
||||
raise Exception(
|
||||
"Undefined activation function: choose from: relu")
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
|
||||
def forward(self, x):
|
||||
# [N,L,C] -> [N,C,L]
|
||||
x = torch.transpose(x, 1, 2)
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
class Linear(nn.Module):
|
||||
"""
|
||||
Linear module
|
||||
@ -12,10 +12,10 @@ class Linear(nn.Module):
|
||||
bidirectional : If True, becomes a bidirectional RNN
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, output_size, bias=True):
|
||||
def __init__(self, input_size, output_size, bias=True,initial_method = None ):
|
||||
super(Linear, self).__init__()
|
||||
self.linear = nn.Linear(input_size, output_size, bias)
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
class Lstm(nn.Module):
|
||||
"""
|
||||
LSTM module
|
||||
@ -13,11 +13,13 @@ class Lstm(nn.Module):
|
||||
bidirectional : If True, becomes a bidirectional RNN. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False):
|
||||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False , initial_method = None):
|
||||
super(Lstm, self).__init__()
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
|
||||
dropout=dropout, bidirectional=bidirectional)
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def forward(self, x):
|
||||
x, _ = self.lstm(x)
|
||||
return x
|
||||
if __name__ == "__main__":
|
||||
lstm = Lstm(10)
|
||||
|
@ -4,7 +4,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
def MaskedRecurrent(reverse=False):
|
||||
def forward(input, hidden, cell, mask, train=True, dropout=0):
|
||||
"""
|
||||
@ -192,7 +192,7 @@ def AutogradMaskedStep(num_layers=1, dropout=0, train=True, lstm=False):
|
||||
class MaskedRNNBase(nn.Module):
|
||||
def __init__(self, Cell, input_size, hidden_size,
|
||||
num_layers=1, bias=True, batch_first=False,
|
||||
layer_dropout=0, step_dropout=0, bidirectional=False, **kwargs):
|
||||
layer_dropout=0, step_dropout=0, bidirectional=False, initial_method = None , **kwargs):
|
||||
"""
|
||||
:param Cell:
|
||||
:param input_size:
|
||||
@ -226,7 +226,7 @@ class MaskedRNNBase(nn.Module):
|
||||
cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs)
|
||||
self.all_cells.append(cell)
|
||||
self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max的代码写得真好看
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def reset_parameters(self):
|
||||
for cell in self.all_cells:
|
||||
cell.reset_parameters()
|
||||
|
@ -6,6 +6,7 @@ import torch.nn.functional as F
|
||||
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
|
||||
def default_initializer(hidden_size):
|
||||
stdv = 1.0 / math.sqrt(hidden_size)
|
||||
@ -172,7 +173,7 @@ def AutogradVarMaskedStep(num_layers=1, lstm=False):
|
||||
class VarMaskedRNNBase(nn.Module):
|
||||
def __init__(self, Cell, input_size, hidden_size,
|
||||
num_layers=1, bias=True, batch_first=False,
|
||||
dropout=(0, 0), bidirectional=False, initializer=None, **kwargs):
|
||||
dropout=(0, 0), bidirectional=False, initializer=None,initial_method = None, **kwargs):
|
||||
|
||||
super(VarMaskedRNNBase, self).__init__()
|
||||
self.Cell = Cell
|
||||
@ -193,7 +194,7 @@ class VarMaskedRNNBase(nn.Module):
|
||||
cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs)
|
||||
self.all_cells.append(cell)
|
||||
self.add_module('cell%d' % (layer * num_directions + direction), cell)
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def reset_parameters(self):
|
||||
for cell in self.all_cells:
|
||||
cell.reset_parameters()
|
||||
@ -284,7 +285,7 @@ class VarFastLSTMCell(VarRNNCellBase):
|
||||
\end{array}
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None):
|
||||
def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None,initial_method =None):
|
||||
super(VarFastLSTMCell, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
@ -311,7 +312,7 @@ class VarFastLSTMCell(VarRNNCellBase):
|
||||
self.p_hidden = p_hidden
|
||||
self.noise_in = None
|
||||
self.noise_hidden = None
|
||||
|
||||
initial_parameter(self, initial_method)
|
||||
def reset_parameters(self):
|
||||
for weight in self.parameters():
|
||||
if weight.dim() == 1:
|
||||
|
@ -2,8 +2,8 @@ from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
import torch.nn.init as init
|
||||
import torch.nn as nn
|
||||
def mask_softmax(matrix, mask):
|
||||
if mask is None:
|
||||
result = torch.nn.functional.softmax(matrix, dim=-1)
|
||||
@ -11,6 +11,51 @@ def mask_softmax(matrix, mask):
|
||||
raise NotImplementedError
|
||||
return result
|
||||
|
||||
def initial_parameter(net ,initial_method =None):
|
||||
|
||||
if initial_method == 'xavier_uniform':
|
||||
init_method = init.xavier_uniform_
|
||||
elif initial_method=='xavier_normal':
|
||||
init_method = init.xavier_normal_
|
||||
elif initial_method == 'kaiming_normal' or initial_method =='msra':
|
||||
init_method = init.kaiming_normal
|
||||
elif initial_method == 'kaiming_uniform':
|
||||
init_method = init.kaiming_normal
|
||||
elif initial_method == 'orthogonal':
|
||||
init_method = init.orthogonal_
|
||||
elif initial_method == 'sparse':
|
||||
init_method = init.sparse_
|
||||
elif initial_method =='normal':
|
||||
init_method = init.normal_
|
||||
elif initial_method =='uniform':
|
||||
initial_method = init.uniform_
|
||||
else:
|
||||
init_method = init.xavier_normal_
|
||||
def weights_init(m):
|
||||
# classname = m.__class__.__name__
|
||||
if isinstance(m, nn.Conv2d) or isinstance(m,nn.Conv1d) or isinstance(m,nn.Conv3d): # for all the cnn
|
||||
if initial_method != None:
|
||||
init_method(m.weight.data)
|
||||
else:
|
||||
init.xavier_normal_(m.weight.data)
|
||||
init.normal_(m.bias.data)
|
||||
elif isinstance(m, nn.LSTM):
|
||||
for w in m.parameters():
|
||||
if len(w.data.size())>1:
|
||||
init_method(w.data) # weight
|
||||
else:
|
||||
init.normal_(w.data) # bias
|
||||
elif hasattr(m, 'weight') and m.weight.requires_grad:
|
||||
init_method(m.weight.data)
|
||||
else:
|
||||
for w in m.parameters() :
|
||||
if w.requires_grad:
|
||||
if len(w.data.size())>1:
|
||||
init_method(w.data) # weight
|
||||
else:
|
||||
init.normal_(w.data) # bias
|
||||
# print("init else")
|
||||
net.apply(weights_init)
|
||||
|
||||
def seq_mask(seq_len, max_len):
|
||||
mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)]
|
||||
|
@ -0,0 +1,13 @@
|
||||
[train]
|
||||
epochs = 30
|
||||
batch_size = 32
|
||||
pickle_path = "./save/"
|
||||
validate = true
|
||||
save_best_dev = true
|
||||
model_saved_path = "./save/"
|
||||
rnn_hidden_units = 300
|
||||
word_emb_dim = 300
|
||||
use_crf = true
|
||||
use_cuda = false
|
||||
loss_func = "cross_entropy"
|
||||
num_classes = 5
|
80
reproduction/LSTM+self_attention_sentiment_analysis/main.py
Normal file
80
reproduction/LSTM+self_attention_sentiment_analysis/main.py
Normal file
@ -0,0 +1,80 @@
|
||||
|
||||
import os
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader
|
||||
from fastNLP.loader.embed_loader import EmbedLoader as EmbedLoader
|
||||
from fastNLP.loader.config_loader import ConfigSection
|
||||
from fastNLP.loader.config_loader import ConfigLoader
|
||||
|
||||
from fastNLP.models.base_model import BaseModel
|
||||
|
||||
from fastNLP.core.preprocess import ClassPreprocess as Preprocess
|
||||
from fastNLP.core.trainer import ClassificationTrainer
|
||||
|
||||
from fastNLP.modules.encoder.embedding import Embedding as Embedding
|
||||
from fastNLP.modules.encoder.lstm import Lstm
|
||||
from fastNLP.modules.aggregation.self_attention import SelfAttention
|
||||
from fastNLP.modules.decoder.MLP import MLP
|
||||
|
||||
|
||||
train_data_path = 'small_train_data.txt'
|
||||
dev_data_path = 'small_dev_data.txt'
|
||||
# emb_path = 'glove.txt'
|
||||
|
||||
lstm_hidden_size = 300
|
||||
embeding_size = 300
|
||||
attention_unit = 350
|
||||
attention_hops = 10
|
||||
class_num = 5
|
||||
nfc = 3000
|
||||
### data load ###
|
||||
train_dataset = Dataset_loader(train_data_path)
|
||||
train_data = train_dataset.load()
|
||||
|
||||
dev_args = Dataset_loader(dev_data_path)
|
||||
dev_data = dev_args.load()
|
||||
|
||||
###### preprocess ####
|
||||
preprocess = Preprocess()
|
||||
word2index, label2index = preprocess.build_dict(train_data)
|
||||
train_data, dev_data = preprocess.run(train_data, dev_data)
|
||||
|
||||
|
||||
|
||||
# emb = EmbedLoader(emb_path)
|
||||
# embedding = emb.load_embedding(emb_dim= embeding_size , emb_file= emb_path ,word_dict= word2index)
|
||||
### construct vocab ###
|
||||
|
||||
class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel):
|
||||
def __init__(self, args=None):
|
||||
super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__()
|
||||
self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None )
|
||||
self.lstm = Lstm(input_size = embeding_size,hidden_size = lstm_hidden_size ,bidirectional = True)
|
||||
self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops)
|
||||
self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ] ,num_class=class_num ,)
|
||||
def forward(self,x):
|
||||
x_emb = self.embedding(x)
|
||||
output = self.lstm(x_emb)
|
||||
after_attention, penalty = self.attention(output,x)
|
||||
after_attention =after_attention.view(after_attention.size(0),-1)
|
||||
output = self.mlp(after_attention)
|
||||
return output
|
||||
|
||||
def loss(self, predict, ground_truth):
|
||||
print("predict:%s; g:%s" % (str(predict.size()), str(ground_truth.size())))
|
||||
print(ground_truth)
|
||||
return F.cross_entropy(predict, ground_truth)
|
||||
|
||||
train_args = ConfigSection()
|
||||
ConfigLoader("good path").load_config('config.cfg',{"train": train_args})
|
||||
train_args['vocab'] = len(word2index)
|
||||
|
||||
|
||||
trainer = ClassificationTrainer(**train_args.data)
|
||||
|
||||
# for k in train_args.__dict__.keys():
|
||||
# print(k, train_args[k])
|
||||
model = SELF_ATTENTION_YELP_CLASSIFICATION(train_args)
|
||||
trainer.train(model,train_data , dev_data)
|
8
setup.py
8
setup.py
@ -2,18 +2,18 @@
|
||||
# coding=utf-8
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
with open('README.md') as f:
|
||||
with open('README.md', encoding='utf-8') as f:
|
||||
readme = f.read()
|
||||
|
||||
with open('LICENSE') as f:
|
||||
with open('LICENSE', encoding='utf-8') as f:
|
||||
license = f.read()
|
||||
|
||||
with open('requirements.txt') as f:
|
||||
with open('requirements.txt', encoding='utf-8') as f:
|
||||
reqs = f.read()
|
||||
|
||||
setup(
|
||||
name='fastNLP',
|
||||
version='0.0.1',
|
||||
version='0.0.3',
|
||||
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
|
||||
long_description=readme,
|
||||
license=license,
|
||||
|
Loading…
Reference in New Issue
Block a user