Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-03 20:57:37 +08:00)

self-attention module

commit 8b859dc7b6, parent b46c4ba042
@@ -5,6 +5,8 @@ import torch.nn.functional as F
 from fastNLP.modules.utils import initial_parameter


 class SelfAttention(nn.Module):
     """
     Self Attention Module.
@@ -15,57 +17,53 @@ class SelfAttention(nn.Module):
     num_vec: int, the number of encoded vectors
     """

-    def __init__(self, input_size, dim=10, num_vec=10, drop=0.5, initial_method=None):
+    def __init__(self, input_size, attention_unit=350, attention_hops=10, drop=0.5, initial_method=None,
+                 use_cuda=False):
         super(SelfAttention, self).__init__()
         # self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
         # self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
-        self.attention_hops = num_vec
-
-        self.ws1 = nn.Linear(input_size, dim, bias=False)
-        self.ws2 = nn.Linear(dim, num_vec, bias=False)
+        self.attention_hops = attention_hops
+        self.ws1 = nn.Linear(input_size, attention_unit, bias=False)
+        self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False)
+        if use_cuda:
+            self.I = Variable(torch.eye(attention_hops).cuda(), requires_grad=False)
+        else:
+            self.I = Variable(torch.eye(attention_hops), requires_grad=False)
+        self.I_origin = self.I
         self.drop = nn.Dropout(drop)
         self.softmax = nn.Softmax(dim=2)
         self.tanh = nn.Tanh()
         initial_parameter(self, initial_method)

-    def penalization(self, A):
+    def penalization(self, attention):
         """
         compute the penalization term for the attention module
         """
-        if self.W_s1.is_cuda:
-            I = Variable(torch.eye(A.size(1)).cuda(), requires_grad=False)
-        else:
-            I = Variable(torch.eye(A.size(1)), requires_grad=False)
-        M = torch.matmul(A, torch.transpose(A, 1, 2)) - I
-        M = M.view(M.size(0), -1)
-        return torch.sum(M ** 2, dim=1)
+        baz = attention.size(0)
+        size = self.I.size()
+        if len(size) != 3 or size[0] != baz:
+            self.I = self.I_origin.expand(baz, -1, -1)
+        attentionT = torch.transpose(attention, 1, 2).contiguous()
+        mat = torch.bmm(attention, attentionT) - self.I[:attention.size(0)]
+        ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5
+        return torch.sum(ret) / size[0]
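
The new penalization computes, per batch element, the Frobenius norm of A·Aᵀ − I (where A is the [hop, len] attention matrix) and averages it over the batch, the redundancy penalty used in structured self-attention (Lin et al., 2017), which this module appears to follow. Outside the diff, a minimal restatement of the intended computation; the frobenius_penalty helper and the shapes are illustrative only, not part of fastNLP (the module caches and expands self.I, while this sketch rebuilds the identity each call):

import torch

def frobenius_penalty(attention):
    # attention: [baz, hop, len]; penalty = mean over the batch of ||A·Aᵀ − I||_F
    baz, hop, _ = attention.size()
    eye = torch.eye(hop).expand(baz, hop, hop)
    mat = torch.bmm(attention, attention.transpose(1, 2)) - eye
    ret = (torch.sum(torch.sum(mat ** 2, 2), 1) + 1e-10) ** 0.5
    return torch.sum(ret) / baz

# quick check against torch.norm on a random, row-normalized attention tensor
A = torch.softmax(torch.randn(4, 10, 20), dim=2)                       # [baz, hop, len]
eye = torch.eye(10).expand(4, 10, 10)
reference = torch.norm(torch.bmm(A, A.transpose(1, 2)) - eye, dim=(1, 2)).mean()
print(frobenius_penalty(A), reference)                                 # nearly identical values

The 1e-10 term only guards the square root at zero, so the two printed values agree to within rounding.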

-    def forward(self, outp, inp):
-        # the following code cannot be used because some tokens are padding; there is no masking module here!
-        # inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2)))  # []
-        # A = self.softmax(torch.matmul(self.W_s2, inter))
-        # out = torch.matmul(A, x)
-        # out = out.view(out.size(0), -1)
-        # penalty = self.penalization(A)
-        # return out, penalty
-        outp = outp.contiguous()
-        size = outp.size()  # [bsz, len, nhid]
-
-        compressed_embeddings = outp.view(-1, size[2])  # [bsz*len, nhid]
-        transformed_inp = torch.transpose(inp, 0, 1).contiguous()  # [bsz, len]
-        transformed_inp = transformed_inp.view(size[0], 1, size[1])  # [bsz, 1, len]
-        concatenated_inp = [transformed_inp for i in range(self.attention_hops)]
-        concatenated_inp = torch.cat(concatenated_inp, 1)  # [bsz, hop, len]
-
-        hbar = self.tanh(self.ws1(self.drop(compressed_embeddings)))  # [bsz*len, attention-unit]
-        attention = self.ws2(hbar).view(size[0], size[1], -1)  # [bsz, len, hop]
-        attention = torch.transpose(attention, 1, 2).contiguous()  # [bsz, hop, len]
-        penalized_alphas = attention + (
-            -10000 * (concatenated_inp == 0).float())
-        # [bsz, hop, len] + [bsz, hop, len]
-        attention = self.softmax(penalized_alphas.view(-1, size[1]))  # [bsz*hop, len]
-        attention = attention.view(size[0], self.attention_hops, size[1])  # [bsz, hop, len]
-        return torch.bmm(attention, outp), attention  # output --> [baz, hop, nhid]
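
Both the removed and the added forward suppress padding positions the same way: a large negative constant is added to the attention logits wherever the token index is 0, so the subsequent softmax assigns those positions (near-)zero weight. A minimal sketch of that trick outside the diff, with made-up shapes and token indices; the module expands the index tensor over the hops explicitly, while this sketch relies on broadcasting, which is equivalent here:

import torch
import torch.nn.functional as F

# toy logits for 1 sentence, 2 hops, 5 tokens; the last 2 tokens are padding (index 0)
logits = torch.randn(1, 2, 5)                    # [baz, hop, len]
token_idx = torch.tensor([[3, 7, 5, 0, 0]])      # [baz, len], 0 = <pad>
mask = (token_idx == 0).float().unsqueeze(1)     # [baz, 1, len], broadcasts over hops

masked = logits + (-999999 * mask)               # padding logits become hugely negative
weights = F.softmax(masked, dim=2)               # [baz, hop, len]
print(weights[0, :, -2:])                        # ~0 weight on the two padding positions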

+    def forward(self, input, input_origin):
+        """
+        :param input: the matrix to compute attention on, [baz, senLen, h_dim]
+        :param input_origin: the token indices, including the pad token (0), [baz, senLen]
+        :return output1: the input matrix after the attention operation, [baz, multi-head, h_dim]
+        :return output2: the attention penalty term, a scalar [1]
+        """
+        input = input.contiguous()
+        size = input.size()  # [bsz, len, nhid]
+
+        input_origin = input_origin.expand(self.attention_hops, -1, -1)  # [hops, baz, len]
+        input_origin = input_origin.transpose(0, 1).contiguous()  # [baz, hops, len]
+
+        y1 = self.tanh(self.ws1(self.drop(input)))  # [baz, len, dim] --> [bsz, len, attention-unit]
+        attention = self.ws2(y1).transpose(1, 2).contiguous()  # [bsz, len, attention-unit] --> [bsz, len, hop] --> [baz, hop, len]
+
+        attention = attention + (-999999 * (input_origin == 0).float())  # remove the weight on padding tokens
+        attention = F.softmax(attention, 2)  # [baz, hop, len]
+        return torch.bmm(attention, input), self.penalization(attention)  # output1 --> [baz, hop, nhid]
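
For reference, a usage sketch of the new interface with dummy tensors; the import path, shapes, and values below are assumptions for illustration, not taken from the repository:

import torch
from fastNLP.modules.aggregation.self_attention import SelfAttention  # import path is an assumption

baz, sen_len, h_dim = 8, 40, 600                   # dummy shapes: batch, sentence length, encoder dim
lstm_output = torch.randn(baz, sen_len, h_dim)     # stand-in for a (Bi)LSTM output
token_idx = torch.randint(1, 100, (baz, sen_len))  # stand-in token indices, 0 = <pad>
token_idx[:, -5:] = 0                              # pretend the last 5 positions are padding

attn = SelfAttention(input_size=h_dim, attention_unit=350, attention_hops=10)
sentence_matrix, penalty = attn(lstm_output, token_idx)
print(sentence_matrix.size())                      # torch.Size([8, 10, 600])
print(penalty.size())                              # a 0-dim tensor (scalar penalty)

Passing the raw token indices as the second argument is what lets the module zero out attention on padding, as in the masking sketch above.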