add transformer

yunfan 2018-10-20 10:54:41 +08:00
parent 102259df39
commit 830d223344
4 changed files with 81 additions and 8 deletions

View File

@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+import math

 from fastNLP.modules.utils import mask_softmax
@@ -17,3 +18,44 @@ class Attention(torch.nn.Module):
     def _atten_forward(self, query, memory):
         raise NotImplementedError
+
+
+class DotAtte(nn.Module):
+    def __init__(self, key_size, value_size):
+        super(DotAtte, self).__init__()
+        self.key_size = key_size
+        self.value_size = value_size
+        self.scale = math.sqrt(key_size)
+
+    def forward(self, Q, K, V, seq_mask=None):
+        """
+        :param Q: [batch, seq_len, key_size]
+        :param K: [batch, seq_len, key_size]
+        :param V: [batch, seq_len, value_size]
+        :param seq_mask: [batch, seq_len]
+        """
+        output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
+        if seq_mask is not None:
+            output.masked_fill_(seq_mask.lt(1), -float('inf'))
+        output = nn.functional.softmax(output, dim=2)
+        return torch.matmul(output, V)
+
+
+class MultiHeadAtte(nn.Module):
+    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+        super(MultiHeadAtte, self).__init__()
+        self.in_linear = nn.ModuleList()
+        for i in range(num_atte * 3):
+            out_feat = key_size if (i % 3) != 2 else value_size
+            self.in_linear.append(nn.Linear(input_size, out_feat))
+        self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)])
+        self.out_linear = nn.Linear(value_size * num_atte, output_size)
+
+    def forward(self, Q, K, V, seq_mask=None):
+        heads = []
+        for i in range(len(self.attes)):
+            j = i * 3
+            qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V)
+            headi = self.attes[i](qi, ki, vi, seq_mask)
+            heads.append(headi)
+        output = torch.cat(heads, dim=2)
+        return self.out_linear(output)
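
DotAtte implements scaled dot-product attention, softmax(QK^T / sqrt(key_size)) V, and MultiHeadAtte projects the inputs through num_atte triples of linear layers, runs one DotAtte per head, and concatenates the heads before a final linear projection. A minimal usage sketch, assuming the two classes above are in scope; tensor sizes are illustrative, not from the commit, and seq_mask is left at its default of None:

import torch

batch, seq_len = 4, 10
key_size, value_size = 16, 16
input_size, output_size, num_atte = 32, 32, 4

# Scaled dot-product attention over random tensors.
Q = torch.rand(batch, seq_len, key_size)
K = torch.rand(batch, seq_len, key_size)
V = torch.rand(batch, seq_len, value_size)
atte = DotAtte(key_size, value_size)
out = atte(Q, K, V)        # [batch, seq_len, value_size]

# Multi-head self-attention: Q = K = V = x.
x = torch.rand(batch, seq_len, input_size)
multi = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
out = multi(x, x, x)       # [batch, seq_len, output_size]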

View File

@@ -0,0 +1,32 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from ..aggregator.attention import MultiHeadAtte
+from ..other_modules import LayerNormalization
+
+
+class TransformerEncoder(nn.Module):
+    class SubLayer(nn.Module):
+        def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+            super(TransformerEncoder.SubLayer, self).__init__()
+            self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
+            self.norm1 = LayerNormalization(output_size)
+            self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
+                                     nn.ReLU(),
+                                     nn.Linear(output_size, output_size))
+            self.norm2 = LayerNormalization(output_size)
+
+        def forward(self, input, seq_mask):
+            attention = self.atte(input)
+            norm_atte = self.norm1(attention + input)
+            output = self.ffn(norm_atte)
+            return self.norm2(output + norm_atte)
+
+    def __init__(self, num_layers, **kargs):
+        super(TransformerEncoder, self).__init__()
+        self.layers = nn.Sequential(*[self.SubLayer(**kargs) for _ in range(num_layers)])
+
+    def forward(self, x, seq_mask=None):
+        return self.layers(x, seq_mask)
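
Each SubLayer follows the usual Transformer encoder block: multi-head self-attention, a residual connection with layer normalization, a position-wise feed-forward network, then a second residual connection with layer normalization. For orientation only, here is a self-contained sketch of that same pattern using stock PyTorch modules; nn.MultiheadAttention and nn.LayerNorm stand in for the custom MultiHeadAtte and LayerNormalization, and the class name and sizes are hypothetical, not from the commit:

import torch
from torch import nn

class EncoderSubLayerSketch(nn.Module):
    def __init__(self, d_model, n_heads=4):
        super().__init__()
        # batch_first=True assumes a reasonably recent PyTorch.
        self.atte = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_model),
                                 nn.ReLU(),
                                 nn.Linear(d_model, d_model))
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):                        # x: [batch, seq_len, d_model]
        attention, _ = self.atte(x, x, x)        # self-attention with Q = K = V = x
        norm_atte = self.norm1(attention + x)    # residual + layer norm
        output = self.ffn(norm_atte)             # position-wise feed-forward
        return self.norm2(output + norm_atte)    # second residual + layer norm

layer = EncoderSubLayerSketch(d_model=32)
out = layer(torch.rand(4, 10, 32))               # [4, 10, 32]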

View File

@@ -31,12 +31,12 @@ class GroupNorm(nn.Module):

 class LayerNormalization(nn.Module):
     """ Layer normalization module """

-    def __init__(self, d_hid, eps=1e-3):
+    def __init__(self, layer_size, eps=1e-3):
         super(LayerNormalization, self).__init__()
         self.eps = eps
-        self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
-        self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
+        self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True))
+        self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True))

     def forward(self, z):
         if z.size(1) == 1:
@@ -44,9 +44,8 @@ class LayerNormalization(nn.Module):
         mu = torch.mean(z, keepdim=True, dim=-1)
         sigma = torch.std(z, keepdim=True, dim=-1)
-        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
-        ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
+        ln_out = (z - mu) / (sigma + self.eps)
+        ln_out = ln_out * self.a_2 + self.b_2

         return ln_out
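
The rewritten LayerNormalization normalizes over the last dimension and stores the gain a_2 and bias b_2 with shape [1, layer_size], relying on broadcasting rather than expand_as. A rough sketch of the new forward computation; the values and shapes are illustrative, not from the commit:

import torch

z = torch.rand(4, 10, 8)                   # [batch, seq_len, layer_size]
mu = torch.mean(z, keepdim=True, dim=-1)
sigma = torch.std(z, keepdim=True, dim=-1)
a_2 = torch.ones(1, 8)                     # learned gain, broadcasts over batch and seq_len
b_2 = torch.zeros(1, 8)                    # learned bias
ln_out = (z - mu) / (sigma + 1e-3) * a_2 + b_2
print(ln_out.shape)                        # torch.Size([4, 10, 8])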

View File

@@ -1,5 +1,5 @@
 [train]
-epochs = 50
+epochs = -1
 batch_size = 16
 pickle_path = "./save/"
 validate = true