基本完成seq2seq基础功能

This commit is contained in:
linzehui 2020-05-03 15:52:10 +08:00
parent 15360e9724
commit b95aa56afb
10 changed files with 908 additions and 833 deletions

View File

@ -9,18 +9,18 @@ fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models
"""
__all__ = [
"CNNText",
"SeqLabeling",
"AdvSeqLabel",
"BiLSTMCRF",
"ESIM",
"StarTransEnc",
"STSeqLabel",
"STNLICls",
"STSeqCls",
"BiaffineParser",
"GraphParser",
@ -30,7 +30,9 @@ __all__ = [
"BertForTokenClassification",
"BertForQuestionAnswering",
"TransformerSeq2SeqModel"
"TransformerSeq2SeqModel",
"LSTMSeq2SeqModel",
"BaseSeq2SeqModel"
]
from .base_model import BaseModel
@ -41,7 +43,8 @@ from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
from .snli import ESIM
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
from .seq2seq_model import TransformerSeq2SeqModel
from .seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel, BaseSeq2SeqModel
import sys
from ..doc_utils import doc_process
doc_process(sys.modules[__name__])
doc_process(sys.modules[__name__])

View File

@ -1,26 +1,153 @@
import torch.nn as nn
import torch
from typing import Union, Tuple
from torch import nn
import numpy as np
from fastNLP.modules import TransformerSeq2SeqDecoder, TransformerSeq2SeqEncoder, TransformerPast
from ..embeddings import StaticEmbedding
from ..modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, Seq2SeqEncoder, LSTMSeq2SeqEncoder
from ..modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder
from ..core import Vocabulary
import argparse
class TransformerSeq2SeqModel(nn.Module): # todo 参考fairseq的FairseqModel的写法
def __init__(self, src_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
tgt_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
num_layers: int = 6, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
bind_input_output_embed=False):
super().__init__()
self.encoder = TransformerSeq2SeqEncoder(src_embed, num_layers, d_model, n_head, dim_ff, dropout)
self.decoder = TransformerSeq2SeqDecoder(tgt_embed, num_layers, d_model, n_head, dim_ff, dropout, output_embed,
bind_input_output_embed)
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
''' Sinusoid position encoding table '''
self.num_layers = num_layers
def cal_angle(position, hid_idx):
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
def forward(self, words, target, seq_len):
encoder_output, encoder_mask = self.encoder(words, seq_len)
past = TransformerPast(encoder_output, encoder_mask, self.num_layers)
outputs = self.decoder(target, past, return_attention=False)
def get_posi_angle_vec(position):
return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
return outputs
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
if padding_idx is not None:
# zero vector for padding dimension
sinusoid_table[padding_idx] = 0.
return torch.FloatTensor(sinusoid_table)
def build_embedding(vocab, embed_dim, model_dir_or_name=None):
"""
todo: 根据需求可丰富该函数的功能目前只返回StaticEmbedding
:param vocab: Vocabulary
:param embed_dim:
:param model_dir_or_name:
:return:
"""
assert isinstance(vocab, Vocabulary)
embed = StaticEmbedding(vocab=vocab, embedding_dim=embed_dim, model_dir_or_name=model_dir_or_name)
return embed
class BaseSeq2SeqModel(nn.Module):
def __init__(self, encoder, decoder):
super(BaseSeq2SeqModel, self).__init__()
self.encoder = encoder
self.decoder = decoder
assert isinstance(self.encoder, Seq2SeqEncoder)
assert isinstance(self.decoder, Seq2SeqDecoder)
def forward(self, src_words, src_seq_len, tgt_prev_words):
encoder_output, encoder_mask = self.encoder(src_words, src_seq_len)
decoder_output = self.decoder(tgt_prev_words, encoder_output, encoder_mask)
return {'tgt_output': decoder_output}
class LSTMSeq2SeqModel(BaseSeq2SeqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dropout', type=float, default=0.3)
parser.add_argument('--embedding_dim', type=int, default=300)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--hidden_size', type=int, default=300)
parser.add_argument('--bidirectional', action='store_true', default=True)
args = parser.parse_args()
return args
@classmethod
def build_model(cls, args, src_vocab, tgt_vocab):
# 处理embedding
src_embed = build_embedding(src_vocab, args.embedding_dim)
if args.share_embedding:
assert src_vocab == tgt_vocab, "share_embedding requires a joined vocab"
tgt_embed = src_embed
else:
tgt_embed = build_embedding(tgt_vocab, args.embedding_dim)
if args.bind_input_output_embed:
output_embed = nn.Parameter(tgt_embed.embedding.weight)
else:
output_embed = nn.Parameter(torch.Tensor(len(tgt_vocab), args.embedding_dim), requires_grad=True)
nn.init.normal_(output_embed, mean=0, std=args.embedding_dim ** -0.5)
encoder = LSTMSeq2SeqEncoder(vocab=src_vocab, embed=src_embed, num_layers=args.num_layers,
hidden_size=args.hidden_size, dropout=args.dropout,
bidirectional=args.bidirectional)
decoder = LSTMSeq2SeqDecoder(vocab=tgt_vocab, embed=tgt_embed, num_layers=args.num_layers,
hidden_size=args.hidden_size, dropout=args.dropout, output_embed=output_embed,
attention=True)
return LSTMSeq2SeqModel(encoder, decoder)
class TransformerSeq2SeqModel(BaseSeq2SeqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--d_model', type=int, default=512)
parser.add_argument('--num_layers', type=int, default=6)
parser.add_argument('--n_head', type=int, default=8)
parser.add_argument('--dim_ff', type=int, default=2048)
parser.add_argument('--bind_input_output_embed', action='store_true', default=True)
parser.add_argument('--share_embedding', action='store_true', default=True)
args = parser.parse_args()
return args
@classmethod
def build_model(cls, args, src_vocab, tgt_vocab):
d_model = args.d_model
args.max_positions = getattr(args, 'max_positions', 1024) # 处理的最长长度
# 处理embedding
src_embed = build_embedding(src_vocab, d_model)
if args.share_embedding:
assert src_vocab == tgt_vocab, "share_embedding requires a joined vocab"
tgt_embed = src_embed
else:
tgt_embed = build_embedding(tgt_vocab, d_model)
if args.bind_input_output_embed:
output_embed = nn.Parameter(tgt_embed.embedding.weight)
else:
output_embed = nn.Parameter(torch.Tensor(len(tgt_vocab), d_model), requires_grad=True)
nn.init.normal_(output_embed, mean=0, std=d_model ** -0.5)
pos_embed = nn.Embedding.from_pretrained(
get_sinusoid_encoding_table(args.max_positions + 1, d_model, padding_idx=0),
freeze=True) # 这里规定0是padding
encoder = TransformerSeq2SeqEncoder(vocab=src_vocab, embed=src_embed, pos_embed=pos_embed,
num_layers=args.num_layers, d_model=args.d_model,
n_head=args.n_head, dim_ff=args.dim_ff, dropout=args.dropout)
decoder = TransformerSeq2SeqDecoder(vocab=tgt_vocab, embed=tgt_embed, pos_embed=pos_embed,
num_layers=args.num_layers, d_model=args.d_model,
n_head=args.n_head, dim_ff=args.dim_ff, dropout=args.dropout,
output_embed=output_embed)
return TransformerSeq2SeqModel(encoder, decoder)

View File

@ -51,15 +51,17 @@ __all__ = [
'summary',
"BiLSTMEncoder",
"TransformerSeq2SeqEncoder",
"LSTMSeq2SeqEncoder",
"Seq2SeqEncoder",
"SequenceGenerator",
"LSTMDecoder",
"LSTMPast",
"TransformerSeq2SeqDecoder",
"LSTMSeq2SeqDecoder",
"Seq2SeqDecoder",
"TransformerPast",
"Decoder",
"LSTMPast",
"Past"
]

View File

@ -9,13 +9,15 @@ __all__ = [
"allowed_transitions",
"SequenceGenerator",
"LSTMDecoder",
"LSTMPast",
"TransformerSeq2SeqDecoder",
"TransformerPast",
"Decoder",
"Past",
"TransformerSeq2SeqDecoder",
"LSTMSeq2SeqDecoder",
"Seq2SeqDecoder"
]
from .crf import ConditionalRandomField
@ -23,4 +25,5 @@ from .crf import allowed_transitions
from .mlp import MLP
from .utils import viterbi_decode
from .seq2seq_generator import SequenceGenerator
from .seq2seq_decoder import *
from .seq2seq_decoder import Seq2SeqDecoder, LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMPast, TransformerPast, \
Past

File diff suppressed because it is too large Load Diff

View File

@ -2,23 +2,29 @@ __all__ = [
"SequenceGenerator"
]
import torch
from .seq2seq_decoder import Decoder
from ...models.seq2seq_model import BaseSeq2SeqModel
from ..encoder.seq2seq_encoder import Seq2SeqEncoder
from ..decoder.seq2seq_decoder import Seq2SeqDecoder
import torch.nn.functional as F
from ...core.utils import _get_model_device
from functools import partial
from ...core import Vocabulary
class SequenceGenerator:
def __init__(self, decoder: Decoder, max_length=20, num_beams=1,
def __init__(self, encoder: Seq2SeqEncoder = None, decoder: Seq2SeqDecoder = None,
max_length=20, num_beams=1,
do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None,
repetition_penalty=1, length_penalty=1.0):
if do_sample:
self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length,
num_beams=num_beams,
temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id,
eos_token_id=eos_token_id, repetition_penalty=repetition_penalty,
length_penalty=length_penalty)
else:
self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length,
num_beams=num_beams,
bos_token_id=bos_token_id, eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty)
@ -32,30 +38,45 @@ class SequenceGenerator:
self.eos_token_id = eos_token_id
self.repetition_penalty = repetition_penalty
self.length_penalty = length_penalty
# self.vocab = tgt_vocab
self.encoder = encoder
self.decoder = decoder
@torch.no_grad()
def generate(self, tokens=None, past=None):
def generate(self, src_tokens: torch.Tensor = None, src_seq_len: torch.Tensor = None, prev_tokens=None):
"""
:param torch.LongTensor tokens: batch_size x length, 开始的token
:param past:
:param src_tokens:
:param src_seq_len:
:param prev_tokens:
:return:
"""
# TODO 需要查看如果tokens长度不是1decode的时候是否还能够直接decode
return self.generate_func(tokens=tokens, past=past)
if self.encoder is not None:
encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len)
else:
encoder_output = encoder_mask = None
# 每次都初始化past
if encoder_output is not None:
self.decoder.init_past(encoder_output, encoder_mask)
else:
self.decoder.init_past()
return self.generate_func(src_tokens, src_seq_len, prev_tokens)
@torch.no_grad()
def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
def greedy_generate(decoder: Seq2SeqDecoder, encoder_output=None, encoder_mask=None,
prev_tokens=None, max_length=20, num_beams=1,
bos_token_id=None, eos_token_id=None,
repetition_penalty=1, length_penalty=1.0):
"""
贪婪地搜索句子
:param Decoder decoder: Decoder对象
:param torch.LongTensor tokens: batch_size x len, decode的输入值如果为None则自动从bos_token_id开始生成
:param Past past: 应该包好encoder的一些输出
:param decoder:
:param encoder_output:
:param encoder_mask:
:param prev_tokens: batch_size x len, decode的输入值如果为None则自动从bos_token_id开始生成
:param int max_length: 生成句子的最大长度
:param int num_beams: 使用多大的beam进行解码
:param int bos_token_id: 如果tokens传入为None则使用bos_token_id开始往后解码
@ -65,11 +86,18 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
:return:
"""
if num_beams == 1:
token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=1, top_k=50, top_p=1,
token_ids = _no_beam_search_generate(decoder=decoder,
encoder_output=encoder_output, encoder_mask=encoder_mask,
prev_tokens=prev_tokens,
max_length=max_length, temperature=1,
top_k=50, top_p=1,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
repetition_penalty=repetition_penalty, length_penalty=length_penalty)
else:
token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams,
token_ids = _beam_search_generate(decoder=decoder,
encoder_output=encoder_output, encoder_mask=encoder_mask,
prev_tokens=prev_tokens, max_length=max_length,
num_beams=num_beams,
temperature=1, top_k=50, top_p=1,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
repetition_penalty=repetition_penalty, length_penalty=length_penalty)
@ -78,14 +106,17 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
@torch.no_grad()
def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, temperature=1.0, top_k=50,
def sample_generate(decoder: Seq2SeqDecoder, encoder_output=None, encoder_mask=None,
prev_tokens=None, max_length=20, num_beams=1,
temperature=1.0, top_k=50,
top_p=1.0, bos_token_id=None, eos_token_id=None, repetition_penalty=1.0, length_penalty=1.0):
"""
使用采样的方法生成句子
:param Decoder decoder: Decoder对象
:param torch.LongTensor tokens: batch_size x len, decode的输入值如果为None则自动从bos_token_id开始生成
:param Past past: 应该包好encoder的一些输出
:param decoder
:param encoder_output:
:param encoder_mask:
:param torch.LongTensor prev_tokens: batch_size x len, decode的输入值如果为None则自动从bos_token_id开始生成
:param int max_length: 生成句子的最大长度
:param int num_beam: 使用多大的beam进行解码
:param float temperature: 采样时的退火大小
@ -99,50 +130,55 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
"""
# 每个位置在生成的时候会sample生成
if num_beams == 1:
token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=temperature,
token_ids = _no_beam_search_generate(decoder=decoder, encoder_output=encoder_output, encoder_mask=encoder_mask,
prev_tokens=prev_tokens, max_length=max_length,
temperature=temperature,
top_k=top_k, top_p=top_p,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
repetition_penalty=repetition_penalty, length_penalty=length_penalty)
else:
token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams,
token_ids = _beam_search_generate(decoder=decoder, encoder_output=encoder_output, encoder_mask=encoder_mask,
prev_tokens=prev_tokens, max_length=max_length,
num_beams=num_beams,
temperature=temperature, top_k=top_k, top_p=top_p,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
repetition_penalty=repetition_penalty, length_penalty=length_penalty)
return token_ids
def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, temperature=1.0, top_k=50,
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
def _no_beam_search_generate(decoder: Seq2SeqDecoder,
encoder_output=None, encoder_mask: torch.Tensor = None,
prev_tokens: torch.Tensor = None, max_length=20,
temperature=1.0, top_k=50,
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=False,
repetition_penalty=1.0, length_penalty=1.0):
if encoder_output is not None:
batch_size = encoder_output.size(0)
else:
assert prev_tokens is not None, "You have to specify either `src_tokens` or `prev_tokens`"
batch_size = prev_tokens.size(0)
device = _get_model_device(decoder)
if tokens is None:
if prev_tokens is None:
if bos_token_id is None:
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
if past is None:
raise RuntimeError("You have to specify either `past` or `tokens`.")
batch_size = past.num_samples()
if batch_size is None:
raise RuntimeError("Cannot infer the number of samples from `past`.")
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
batch_size = tokens.size(0)
if past is not None:
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."
raise RuntimeError("You have to specify either `prev_tokens` or `bos_token_id`.")
prev_tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
if eos_token_id is None:
_eos_token_id = float('nan')
else:
_eos_token_id = eos_token_id
for i in range(tokens.size(1)):
scores, past = decoder.decode(tokens[:, :i + 1], past) # batch_size x vocab_size, Past
for i in range(prev_tokens.size(1)): # 先过一遍pretoken做初始化
decoder.decode(prev_tokens[:, :i + 1], encoder_output, encoder_mask)
token_ids = tokens.clone()
token_ids = prev_tokens.clone() # 保存所有生成的token
cur_len = token_ids.size(1)
dones = token_ids.new_zeros(batch_size).eq(1)
# tokens = tokens[:, -1:]
while cur_len < max_length:
scores, past = decoder.decode(tokens, past) # batch_size x vocab_size, Past
scores = decoder.decode(token_ids, encoder_output, encoder_mask) # batch_size x vocab_size
if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
@ -171,9 +207,9 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
next_tokens = torch.argmax(scores, dim=-1) # batch_size
next_tokens = next_tokens.masked_fill(dones, 0) # 对已经搜索完成的sample做padding
tokens = next_tokens.unsqueeze(1)
next_tokens = next_tokens.unsqueeze(1)
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len
token_ids = torch.cat([token_ids, next_tokens], dim=-1) # batch_size x max_len
end_mask = next_tokens.eq(_eos_token_id)
dones = dones.__or__(end_mask)
@ -189,29 +225,31 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
return token_ids
def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, num_beams=4, temperature=1.0,
def _beam_search_generate(decoder: Seq2SeqDecoder,
encoder_output=None, encoder_mask: torch.Tensor = None,
prev_tokens: torch.Tensor = None, max_length=20, num_beams=4, temperature=1.0,
top_k=50,
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=False,
repetition_penalty=1.0, length_penalty=None) -> torch.LongTensor:
# 进行beam search
device = _get_model_device(decoder)
if tokens is None:
if bos_token_id is None:
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
if past is None:
raise RuntimeError("You have to specify either `past` or `tokens`.")
batch_size = past.num_samples()
if batch_size is None:
raise RuntimeError("Cannot infer the number of samples from `past`.")
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
batch_size = tokens.size(0)
if past is not None:
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."
for i in range(tokens.size(1) - 1): # 如果输入的长度较长先decode
scores, past = decoder.decode(tokens[:, :i + 1],
past) # (batch_size, vocab_size), Past
scores, past = decoder.decode(tokens, past) # 这里要传入的是整个句子的长度
if encoder_output is not None:
batch_size = encoder_output.size(0)
else:
assert prev_tokens is not None, "You have to specify either `src_tokens` or `prev_tokens`"
batch_size = prev_tokens.size(0)
device = _get_model_device(decoder)
if prev_tokens is None:
if bos_token_id is None:
raise RuntimeError("You have to specify either `prev_tokens` or `bos_token_id`.")
prev_tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
for i in range(prev_tokens.size(1)): # 如果输入的长度较长先decode
scores = decoder.decode(prev_tokens[:, :i + 1], encoder_output, encoder_mask)
vocab_size = scores.size(1)
assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."
@ -225,15 +263,15 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
# 得到(batch_size, num_beams), (batch_size, num_beams)
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
# 根据index来做顺序的调转
indices = torch.arange(batch_size, dtype=torch.long).to(device)
indices = indices.repeat_interleave(num_beams)
decoder.reorder_past(indices, past)
decoder.reorder_past(indices)
prev_tokens = prev_tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length
tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length
# 记录生成好的token (batch_size', cur_len)
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
token_ids = torch.cat([prev_tokens, next_tokens.view(-1, 1)], dim=-1)
dones = [False] * batch_size
tokens = next_tokens.view(-1, 1)
beam_scores = next_scores.view(-1) # batch_size * num_beams
@ -247,7 +285,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
while cur_len < max_length:
scores, past = decoder.decode(tokens, past) # batch_size * num_beams x vocab_size, Past
scores = decoder.decode(token_ids, encoder_output, encoder_mask) # batch_size * num_beams x vocab_size
if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
@ -300,9 +338,9 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
beam_scores = _next_scores.view(-1)
# 更改past状态, 重组token_ids
# 重组past/encoder状态, 重组token_ids
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维
decoder.reorder_past(reorder_inds, past)
decoder.reorder_past(reorder_inds)
flag = True
if cur_len + 1 == max_length:
@ -327,8 +365,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
# 重新组织token_ids的状态
tokens = _next_tokens
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
cur_tokens = _next_tokens
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), cur_tokens], dim=-1)
for batch_idx in range(batch_size):
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
@ -436,38 +474,3 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
if __name__ == '__main__':
# TODO 需要检查一下greedy_generate和sample_generate是否正常工作。
from torch import nn
class DummyDecoder(nn.Module):
def __init__(self, num_words):
super().__init__()
self.num_words = num_words
def decode(self, tokens, past):
batch_size = tokens.size(0)
return torch.randn(batch_size, self.num_words), past
def reorder_past(self, indices, past):
return past
num_words = 10
batch_size = 3
decoder = DummyDecoder(num_words)
tokens = greedy_generate(decoder=decoder, tokens=torch.zeros(batch_size, 1).long(), past=None, max_length=20,
num_beams=2,
bos_token_id=0, eos_token_id=num_words - 1,
repetition_penalty=1, length_penalty=1.0)
print(tokens)
tokens = sample_generate(decoder, tokens=torch.zeros(batch_size, 1).long(),
past=None, max_length=20, num_beams=2, temperature=1.0, top_k=50,
top_p=1.0, bos_token_id=0, eos_token_id=num_words - 1, repetition_penalty=1.0,
length_penalty=1.0)
print(tokens)

View File

@ -31,8 +31,9 @@ __all__ = [
"BiAttention",
"SelfAttention",
"BiLSTMEncoder",
"TransformerSeq2SeqEncoder"
"LSTMSeq2SeqEncoder",
"TransformerSeq2SeqEncoder",
"Seq2SeqEncoder"
]
from .attention import MultiHeadAttention, BiAttention, SelfAttention
@ -45,4 +46,4 @@ from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU
from .seq2seq_encoder import BiLSTMEncoder, TransformerSeq2SeqEncoder
from .seq2seq_encoder import LSTMSeq2SeqEncoder, TransformerSeq2SeqEncoder, Seq2SeqEncoder

View File

@ -1,48 +1,238 @@
__all__ = [
"TransformerSeq2SeqEncoder",
"BiLSTMEncoder"
]
from torch import nn
import torch.nn as nn
import torch
from ...modules.encoder import LSTM
from ...core.utils import seq_len_to_mask
from torch.nn import TransformerEncoder
from torch.nn import LayerNorm
import torch.nn.functional as F
from typing import Union, Tuple
import numpy as np
from ...core.utils import seq_len_to_mask
import math
from ...core import Vocabulary
from ...modules import LSTM
class TransformerSeq2SeqEncoder(nn.Module):
def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers: int = 6,
d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1):
super(TransformerSeq2SeqEncoder, self).__init__()
self.embed = embed
self.transformer = TransformerEncoder(nn.TransformerEncoderLayer(d_model, n_head,dim_ff,dropout), num_layers)
class MultiheadAttention(nn.Module): # todo 这个要放哪里?
def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None):
super(MultiheadAttention, self).__init__()
self.d_model = d_model
self.n_head = n_head
self.dropout = dropout
self.head_dim = d_model // n_head
self.layer_idx = layer_idx
assert d_model % n_head == 0, "d_model should be divisible by n_head"
self.scaling = self.head_dim ** -0.5
def forward(self, words, seq_len):
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.reset_parameters()
def forward(self, query, key, value, key_mask=None, attn_mask=None, past=None):
"""
:param words: batch, seq_len
:param seq_len:
:return: output: (batch, seq_len,dim) ; encoder_mask
:param query: batch x seq x dim
:param key:
:param value:
:param key_mask: batch x seq 用于指示哪些key不要attend到注意到mask为1的地方是要attend到的
:param attn_mask: seq x seq, 用于mask掉attention map 主要是用在训练时decoder端的self attention下三角为1
:param past: 过去的信息在inference的时候会用到比如encoder outputdecoder的prev kv这样可以减少计算
:return:
"""
words = self.embed(words) # batch, seq_len, dim
words = words.transpose(0, 1)
encoder_mask = seq_len_to_mask(seq_len) # batch, seq
words = self.transformer(words, src_key_padding_mask=~encoder_mask) # seq_len,batch,dim
assert key.size() == value.size()
if past is not None:
assert self.layer_idx is not None
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
return words.transpose(0, 1), encoder_mask
q = self.q_proj(query) # batch x seq x dim
q *= self.scaling
k = v = None
prev_k = prev_v = None
# 从past中取kv
if past is not None: # 说明此时在inference阶段
if qkv_same: # 此时在decoder self attention
prev_k = past.decoder_prev_key[self.layer_idx]
prev_v = past.decoder_prev_value[self.layer_idx]
else: # 此时在decoder-encoder attention直接将保存下来的key装载起来即可
k = past.encoder_key[self.layer_idx]
v = past.encoder_value[self.layer_idx]
if k is None:
k = self.k_proj(key)
v = self.v_proj(value)
if prev_k is not None:
k = torch.cat((prev_k, k), dim=1)
v = torch.cat((prev_v, v), dim=1)
# 更新past
if past is not None:
if qkv_same:
past.decoder_prev_key[self.layer_idx] = k
past.decoder_prev_value[self.layer_idx] = v
else:
past.encoder_key[self.layer_idx] = k
past.encoder_value[self.layer_idx] = v
# 开始计算attention
batch_size, q_len, d_model = query.size()
k_len, v_len = k.size(1), v.size(1)
q = q.contiguous().view(batch_size, q_len, self.n_head, self.head_dim)
k = k.contiguous().view(batch_size, k_len, self.n_head, self.head_dim)
v = v.contiguous().view(batch_size, v_len, self.n_head, self.head_dim)
attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k) # bs,q_len,k_len,n_head
if key_mask is not None:
_key_mask = ~key_mask[:, None, :, None].bool() # batch,1,k_len,n_head
attn_weights = attn_weights.masked_fill(_key_mask, -float('inf'))
if attn_mask is not None:
_attn_mask = ~attn_mask[None, :, :, None].bool() # 1,q_len,k_len,n_head
attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf'))
attn_weights = F.softmax(attn_weights, dim=2)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim
output = output.reshape(batch_size, q_len, -1)
output = self.out_proj(output) # batch,q_len,dim
return output, attn_weights
def reset_parameters(self):
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
def set_layer_idx(self, layer_idx):
self.layer_idx = layer_idx
class BiLSTMEncoder(nn.Module):
def __init__(self, embed, num_layers=3, hidden_size=400, dropout=0.3):
class TransformerSeq2SeqEncoderLayer(nn.Module):
def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048,
dropout: float = 0.1):
super(TransformerSeq2SeqEncoderLayer, self).__init__()
self.d_model = d_model
self.n_head = n_head
self.dim_ff = dim_ff
self.dropout = dropout
self.self_attn = MultiheadAttention(d_model, n_head, dropout)
self.attn_layer_norm = LayerNorm(d_model)
self.ffn_layer_norm = LayerNorm(d_model)
self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(self.dim_ff, self.d_model),
nn.Dropout(dropout))
def forward(self, x, encoder_mask):
"""
:param x: batch,src_seq,dim
:param encoder_mask: batch,src_seq
:return:
"""
# attention
residual = x
x = self.attn_layer_norm(x)
x, _ = self.self_attn(query=x,
key=x,
value=x,
key_mask=encoder_mask)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
# ffn
residual = x
x = self.ffn_layer_norm(x)
x = self.ffn(x)
x = residual + x
return x
class Seq2SeqEncoder(nn.Module):
def __init__(self, vocab):
super().__init__()
self.vocab = vocab
def forward(self, src_words, src_seq_len):
raise NotImplementedError
class TransformerSeq2SeqEncoder(Seq2SeqEncoder):
def __init__(self, vocab: Vocabulary, embed: nn.Module, pos_embed: nn.Module = None, num_layers: int = 6,
d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1):
super(TransformerSeq2SeqEncoder, self).__init__(vocab)
self.embed = embed
self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size // 2, bidirectional=True,
self.embed_scale = math.sqrt(d_model)
self.pos_embed = pos_embed
self.num_layers = num_layers
self.d_model = d_model
self.n_head = n_head
self.dim_ff = dim_ff
self.dropout = dropout
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout)
for _ in range(num_layers)])
self.layer_norm = LayerNorm(d_model)
def forward(self, src_words, src_seq_len):
"""
:param src_words: batch, src_seq_len
:param src_seq_len: [batch]
:return:
"""
batch_size, max_src_len = src_words.size()
device = src_words.device
x = self.embed(src_words) * self.embed_scale # batch, seq, dim
if self.pos_embed is not None:
position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device)
x += self.pos_embed(position)
x = F.dropout(x, p=self.dropout, training=self.training)
encoder_mask = seq_len_to_mask(src_seq_len)
encoder_mask = encoder_mask.to(device)
for layer in self.layer_stacks:
x = layer(x, encoder_mask)
x = self.layer_norm(x)
return x, encoder_mask
class LSTMSeq2SeqEncoder(Seq2SeqEncoder):
def __init__(self, vocab: Vocabulary, embed: nn.Module, num_layers: int = 3, hidden_size: int = 400,
dropout: float = 0.3, bidirectional=True):
super().__init__(vocab)
self.embed = embed
self.num_layers = num_layers
self.dropout = dropout
self.hidden_size = hidden_size
self.bidirectional = bidirectional
self.lstm = LSTM(input_size=embed.embedding_dim, hidden_size=hidden_size // 2, bidirectional=bidirectional,
batch_first=True, dropout=dropout, num_layers=num_layers)
def forward(self, words, seq_len):
words = self.embed(words)
words, hx = self.lstm(words, seq_len)
def forward(self, src_words, src_seq_len):
batch_size = src_words.size(0)
device = src_words.device
x = self.embed(src_words)
x, (final_hidden, final_cell) = self.lstm(x, src_seq_len)
encoder_mask = seq_len_to_mask(src_seq_len).to(device)
return words, hx
# x: batch,seq_len,dim; h/c: num_layers*2,batch,dim
def concat_bidir(input):
output = input.view(self.num_layers, 2, batch_size, -1).transpose(1, 2).contiguous()
return output.view(self.num_layers, batch_size, -1)
if self.bidirectional:
final_hidden = concat_bidir(final_hidden) # 将双向的hidden state拼接起来用于接下来的decoder的input
final_cell = concat_bidir(final_cell)
return (x, (final_hidden, final_cell)), encoder_mask # 为了配合Seq2SeqBaseModel的forward这边需要分为两个return

View File

@ -7,10 +7,12 @@ from transformer.Layers import EncoderLayer, DecoderLayer
__author__ = "Yu-Hsiang Huang"
def get_non_pad_mask(seq):
assert seq.dim() == 2
return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
''' Sinusoid position encoding table '''
@ -31,6 +33,7 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
return torch.FloatTensor(sinusoid_table)
def get_attn_key_pad_mask(seq_k, seq_q):
''' For masking out the padding part of key sequence. '''
@ -41,6 +44,7 @@ def get_attn_key_pad_mask(seq_k, seq_q):
return padding_mask
def get_subsequent_mask(seq):
''' For masking out the subsequent info. '''
@ -51,6 +55,7 @@ def get_subsequent_mask(seq):
return subsequent_mask
class Encoder(nn.Module):
''' A encoder model with self attention mechanism. '''
@ -98,6 +103,7 @@ class Encoder(nn.Module):
return enc_output, enc_slf_attn_list
return enc_output,
class Decoder(nn.Module):
''' A decoder model with self attention mechanism. '''
@ -152,6 +158,7 @@ class Decoder(nn.Module):
return dec_output, dec_slf_attn_list, dec_enc_attn_list
return dec_output,
class Transformer(nn.Module):
''' A sequence to sequence model with attention mechanism. '''
@ -181,8 +188,8 @@ class Transformer(nn.Module):
nn.init.xavier_normal_(self.tgt_word_prj.weight)
assert d_model == d_word_vec, \
'To facilitate the residual connections, \
the dimensions of all module outputs shall be the same.'
'To facilitate the residual connections, \
the dimensions of all module outputs shall be the same.'
if tgt_emb_prj_weight_sharing:
# Share the weight matrix between target word embedding & the final logit dense layer
@ -194,7 +201,7 @@ class Transformer(nn.Module):
if emb_src_tgt_weight_sharing:
# Share the weight matrix between source & target word embeddings
assert n_src_vocab == n_tgt_vocab, \
"To share word embedding table, the vocabulary size of src/tgt shall be the same."
"To share word embedding table, the vocabulary size of src/tgt shall be the same."
self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight
def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):

View File

@ -2,8 +2,10 @@ import unittest
import torch
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, \
LSTMSeq2SeqDecoder
from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.core.utils import seq_len_to_mask
@ -15,22 +17,17 @@ class TestTransformerSeq2SeqDecoder(unittest.TestCase):
vocab.add_word_lst("Another test !".split())
embed = StaticEmbedding(vocab, embedding_dim=512)
encoder = TransformerSeq2SeqEncoder(embed)
decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True)
args = TransformerSeq2SeqModel.add_args()
model = TransformerSeq2SeqModel.build_model(args, vocab, vocab)
src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
src_seq_len = torch.LongTensor([3, 2])
encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask)
output = model(src_words_idx, src_seq_len, tgt_words_idx)
print(output)
decoder_outputs = decoder(tgt_words_idx, past)
print(decoder_outputs)
print(mask)
self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))
# self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))
def test_decode(self):
pass # todo