Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-11 10:05:30 +08:00)
add transformer
commit 830d223344
parent 102259df39
fastNLP/modules/aggregator/attention.py
@@ -1,5 +1,6 @@
 import torch
 
 from torch import nn
+import math
 from fastNLP.modules.utils import mask_softmax
 
@@ -17,3 +18,44 @@ class Attention(torch.nn.Module):
 
     def _atten_forward(self, query, memory):
         raise NotImplementedError
+
+class DotAtte(nn.Module):
+    def __init__(self, key_size, value_size):
+        super(DotAtte, self).__init__()
+        self.key_size = key_size
+        self.value_size = value_size
+        self.scale = math.sqrt(key_size)
+
+    def forward(self, Q, K, V, seq_mask=None):
+        """
+
+        :param Q: [batch, seq_len, key_size]
+        :param K: [batch, seq_len, key_size]
+        :param V: [batch, seq_len, value_size]
+        :param seq_mask: [batch, seq_len]
+        """
+        output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
+        if seq_mask is not None:
+            output.masked_fill_(seq_mask.lt(1), -float('inf'))
+        output = nn.functional.softmax(output, dim=2)
+        return torch.matmul(output, V)
+
+class MultiHeadAtte(nn.Module):
+    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+        super(MultiHeadAtte, self).__init__()
+        self.in_linear = nn.ModuleList()
+        for i in range(num_atte * 3):
+            out_feat = key_size if (i % 3) != 2 else value_size
+            self.in_linear.append(nn.Linear(input_size, out_feat))
+        self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)])
+        self.out_linear = nn.Linear(value_size * num_atte, output_size)
+
+    def forward(self, Q, K, V, seq_mask=None):
+        heads = []
+        for i in range(len(self.attes)):
+            j = i * 3
+            qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V)
+            headi = self.attes[i](qi, ki, vi, seq_mask)
+            heads.append(headi)
+        output = torch.cat(heads, dim=2)
+        return self.out_linear(output)
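For reference, here is a minimal usage sketch of the two new modules (batch and size values are illustrative, not from the commit; the import path follows transformer.py's relative import, i.e. fastNLP.modules.aggregator.attention at this commit). One caveat: the docstring describes seq_mask as [batch, seq_len], but masked_fill_ is applied to the [batch, seq_len, seq_len] score matrix, against which a 2-D mask generally does not broadcast, so the sketch reshapes it to [batch, 1, seq_len] first.

import torch
from fastNLP.modules.aggregator.attention import DotAtte, MultiHeadAtte

batch, seq_len = 2, 5
key_size, value_size = 16, 16
input_size, output_size, num_atte = 32, 32, 4

# Scaled dot-product attention: Q and K share key_size, V carries value_size.
Q = torch.randn(batch, seq_len, key_size)
K = torch.randn(batch, seq_len, key_size)
V = torch.randn(batch, seq_len, value_size)

# Mask reshaped to [batch, 1, seq_len] so it broadcasts over the
# [batch, seq_len, seq_len] attention scores (1 = keep, 0 = padding).
seq_mask = torch.ones(batch, seq_len).long().unsqueeze(1)

atte = DotAtte(key_size, value_size)
print(atte(Q, K, V, seq_mask).shape)   # torch.Size([2, 5, 16])

# Multi-head wrapper: num_atte independent Q/K/V projections of the same
# input, heads concatenated along the feature dim and mapped to output_size.
x = torch.randn(batch, seq_len, input_size)
mha = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
print(mha(x, x, x, seq_mask).shape)    # torch.Size([2, 5, 32])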
fastNLP/modules/encoder/transformer.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from ..aggregator.attention import MultiHeadAtte
+from ..other_modules import LayerNormalization
+
+class TransformerEncoder(nn.Module):
+    class SubLayer(nn.Module):
+        def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+            super(TransformerEncoder.SubLayer, self).__init__()
+            self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
+            self.norm1 = LayerNormalization(output_size)
+            self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
+                                     nn.ReLU(),
+                                     nn.Linear(output_size, output_size))
+            self.norm2 = LayerNormalization(output_size)
+
+        def forward(self, input, seq_mask):
+            attention = self.atte(input)
+            norm_atte = self.norm1(attention + input)
+            output = self.ffn(norm_atte)
+            return self.norm2(output + norm_atte)
+
+    def __init__(self, num_layers, **kargs):
+        super(TransformerEncoder, self).__init__()
+        self.layers = nn.Sequential(*[self.SubLayer(**kargs) for _ in range(num_layers)])
+
+    def forward(self, x, seq_mask=None):
+        return self.layers(x, seq_mask)
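Note that, as committed, TransformerEncoder cannot run end to end: SubLayer.forward calls self.atte(input) with a single argument although MultiHeadAtte.forward expects Q, K and V, and nn.Sequential has no way to thread the extra seq_mask argument through the stack in TransformerEncoder.forward. The sketch below is an illustrative variant, not the committed class: it shows the post-norm residual wiring the sub-layer appears to aim for, with self-attention written explicitly as atte(x, x, x, seq_mask), the mask passed through a ModuleList by hand, and nn.LayerNorm standing in for fastNLP's LayerNormalization to keep the snippet self-contained. It assumes output_size == input_size so the residual addition is valid, as in the example sizes below.

import torch
from torch import nn
from fastNLP.modules.aggregator.attention import MultiHeadAtte

class EncoderBlockSketch(nn.Module):
    """Illustrative post-norm encoder block; not the committed SubLayer."""
    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
        super().__init__()
        self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
        self.norm1 = nn.LayerNorm(output_size)
        self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
                                 nn.ReLU(),
                                 nn.Linear(output_size, output_size))
        self.norm2 = nn.LayerNorm(output_size)

    def forward(self, x, seq_mask=None):
        # Self-attention: queries, keys and values all come from x.
        norm_atte = self.norm1(self.atte(x, x, x, seq_mask) + x)
        return self.norm2(self.ffn(norm_atte) + norm_atte)

blocks = nn.ModuleList([EncoderBlockSketch(32, 32, 16, 16, 4) for _ in range(2)])
x = torch.randn(2, 5, 32)
for block in blocks:   # thread the (optional) mask through each block by hand
    x = block(x)
print(x.shape)         # torch.Size([2, 5, 32])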
fastNLP/modules/other_modules.py
@@ -31,12 +31,12 @@ class GroupNorm(nn.Module):
 class LayerNormalization(nn.Module):
     """ Layer normalization module """
 
-    def __init__(self, d_hid, eps=1e-3):
+    def __init__(self, layer_size, eps=1e-3):
         super(LayerNormalization, self).__init__()
 
         self.eps = eps
-        self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
-        self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
+        self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True))
+        self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True))
 
     def forward(self, z):
         if z.size(1) == 1:
@@ -44,9 +44,8 @@ class LayerNormalization(nn.Module):
 
         mu = torch.mean(z, keepdim=True, dim=-1)
         sigma = torch.std(z, keepdim=True, dim=-1)
-        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
-        ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
-
+        ln_out = (z - mu) / (sigma + self.eps)
+        ln_out = ln_out * self.a_2 + self.b_2
         return ln_out
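A quick sanity check of the patched normalization: the class below re-creates the new __init__ and forward from the hunks above so it can run standalone (the z.size(1) == 1 branch, whose body falls outside the hunk, is omitted). Two details worth noting: requires_grad=True is now passed to torch.ones/torch.zeros rather than to nn.Parameter, which defaults to requires_grad=True anyway, and torch.std uses the unbiased estimator, so the result is close to but not identical to torch.nn.LayerNorm.

import torch
from torch import nn

class LayerNormalizationSketch(nn.Module):
    """Standalone copy of the patched LayerNormalization, for illustration only."""
    def __init__(self, layer_size, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True))
        self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True))

    def forward(self, z):
        mu = torch.mean(z, keepdim=True, dim=-1)
        sigma = torch.std(z, keepdim=True, dim=-1)   # unbiased std, unlike nn.LayerNorm
        ln_out = (z - mu) / (sigma + self.eps)
        return ln_out * self.a_2 + self.b_2          # [1, layer_size] scale/shift broadcast over batch and seq

ln = LayerNormalizationSketch(8)
x = torch.randn(2, 5, 8)
y = ln(x)
print(y.shape)                     # torch.Size([2, 5, 8])
print(y.mean(dim=-1).abs().max())  # close to 0: each position is centered
print(y.std(dim=-1).mean())        # close to 1: and scaled to roughly unit std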
@@ -1,5 +1,5 @@
 [train]
-epochs = 50
+epochs = -1
 batch_size = 16
 pickle_path = "./save/"
 validate = true