Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-04 13:17:51 +08:00)

Commit b95aa56afb (parent 15360e9724): basically complete the core seq2seq functionality
@@ -9,18 +9,18 @@ fastNLP's :mod:`~fastNLP.models` module has built-in models such as :class:`~fastNLP.models
 """
 __all__ = [
     "CNNText",
 
     "SeqLabeling",
     "AdvSeqLabel",
     "BiLSTMCRF",
 
     "ESIM",
 
     "StarTransEnc",
     "STSeqLabel",
     "STNLICls",
     "STSeqCls",
 
     "BiaffineParser",
     "GraphParser",
 
@@ -30,7 +30,9 @@ __all__ = [
     "BertForTokenClassification",
     "BertForQuestionAnswering",
 
-    "TransformerSeq2SeqModel"
+    "TransformerSeq2SeqModel",
+    "LSTMSeq2SeqModel",
+    "BaseSeq2SeqModel"
 ]
 
 from .base_model import BaseModel
@@ -41,7 +43,8 @@ from .cnn_text_classification import CNNText
 from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
 from .snli import ESIM
 from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
-from .seq2seq_model import TransformerSeq2SeqModel
+from .seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel, BaseSeq2SeqModel
 import sys
 from ..doc_utils import doc_process
-doc_process(sys.modules[__name__])
+
+doc_process(sys.modules[__name__])
@@ -1,26 +1,153 @@
-import torch.nn as nn
 import torch
-from typing import Union, Tuple
+from torch import nn
 import numpy as np
-from fastNLP.modules import TransformerSeq2SeqDecoder, TransformerSeq2SeqEncoder, TransformerPast
+from ..embeddings import StaticEmbedding
+from ..modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, Seq2SeqEncoder, LSTMSeq2SeqEncoder
+from ..modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder
+from ..core import Vocabulary
+import argparse
 
 
-class TransformerSeq2SeqModel(nn.Module):  # todo: follow the approach of fairseq's FairseqModel
-    def __init__(self, src_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
-                 tgt_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
-                 num_layers: int = 6, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
-                 output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
-                 bind_input_output_embed=False):
-        super().__init__()
-        self.encoder = TransformerSeq2SeqEncoder(src_embed, num_layers, d_model, n_head, dim_ff, dropout)
-        self.decoder = TransformerSeq2SeqDecoder(tgt_embed, num_layers, d_model, n_head, dim_ff, dropout, output_embed,
-                                                 bind_input_output_embed)
-
-        self.num_layers = num_layers
-
-    def forward(self, words, target, seq_len):
-        encoder_output, encoder_mask = self.encoder(words, seq_len)
-        past = TransformerPast(encoder_output, encoder_mask, self.num_layers)
-        outputs = self.decoder(target, past, return_attention=False)
-
-        return outputs
+def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
+    ''' Sinusoid position encoding table '''
+
+    def cal_angle(position, hid_idx):
+        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
+
+    def get_posi_angle_vec(position):
+        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
+
+    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
+
+    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+
+    if padding_idx is not None:
+        # zero vector for padding dimension
+        sinusoid_table[padding_idx] = 0.
+
+    return torch.FloatTensor(sinusoid_table)
+
+
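For reference, the `get_sinusoid_encoding_table` added above fills the usual sinusoidal position-encoding table; in the function's own notation (`pos` for the position index, `d_hid` for the embedding width) the entries are, as a sketch:

```latex
PE(pos, 2i)   = \sin\!\big(pos / 10000^{2i/d_{hid}}\big), \qquad
PE(pos, 2i+1) = \cos\!\big(pos / 10000^{2i/d_{hid}}\big)
```

Rows are positions and columns are dimensions; the row at `padding_idx` is zeroed so padding positions contribute nothing.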
+def build_embedding(vocab, embed_dim, model_dir_or_name=None):
+    """
+    todo: extend this helper as needed; for now it only returns a StaticEmbedding
+
+    :param vocab: Vocabulary
+    :param embed_dim:
+    :param model_dir_or_name:
+    :return:
+    """
+    assert isinstance(vocab, Vocabulary)
+    embed = StaticEmbedding(vocab=vocab, embedding_dim=embed_dim, model_dir_or_name=model_dir_or_name)
+
+    return embed
+
+
+class BaseSeq2SeqModel(nn.Module):
+    def __init__(self, encoder, decoder):
+        super(BaseSeq2SeqModel, self).__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        assert isinstance(self.encoder, Seq2SeqEncoder)
+        assert isinstance(self.decoder, Seq2SeqDecoder)
+
+    def forward(self, src_words, src_seq_len, tgt_prev_words):
+        encoder_output, encoder_mask = self.encoder(src_words, src_seq_len)
+        decoder_output = self.decoder(tgt_prev_words, encoder_output, encoder_mask)
+
+        return {'tgt_output': decoder_output}
+
+
+class LSTMSeq2SeqModel(BaseSeq2SeqModel):
+    def __init__(self, encoder, decoder):
+        super().__init__(encoder, decoder)
+
+    @staticmethod
+    def add_args():
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--dropout', type=float, default=0.3)
+        parser.add_argument('--embedding_dim', type=int, default=300)
+        parser.add_argument('--num_layers', type=int, default=3)
+        parser.add_argument('--hidden_size', type=int, default=300)
+        parser.add_argument('--bidirectional', action='store_true', default=True)
+        args = parser.parse_args()
+
+        return args
+
+    @classmethod
+    def build_model(cls, args, src_vocab, tgt_vocab):
+        # build the embeddings
+        src_embed = build_embedding(src_vocab, args.embedding_dim)
+        if args.share_embedding:
+            assert src_vocab == tgt_vocab, "share_embedding requires a joined vocab"
+            tgt_embed = src_embed
+        else:
+            tgt_embed = build_embedding(tgt_vocab, args.embedding_dim)
+
+        if args.bind_input_output_embed:
+            output_embed = nn.Parameter(tgt_embed.embedding.weight)
+        else:
+            output_embed = nn.Parameter(torch.Tensor(len(tgt_vocab), args.embedding_dim), requires_grad=True)
+            nn.init.normal_(output_embed, mean=0, std=args.embedding_dim ** -0.5)
+
+        encoder = LSTMSeq2SeqEncoder(vocab=src_vocab, embed=src_embed, num_layers=args.num_layers,
+                                     hidden_size=args.hidden_size, dropout=args.dropout,
+                                     bidirectional=args.bidirectional)
+        decoder = LSTMSeq2SeqDecoder(vocab=tgt_vocab, embed=tgt_embed, num_layers=args.num_layers,
+                                     hidden_size=args.hidden_size, dropout=args.dropout, output_embed=output_embed,
+                                     attention=True)
+
+        return LSTMSeq2SeqModel(encoder, decoder)
+
+
+class TransformerSeq2SeqModel(BaseSeq2SeqModel):
+    def __init__(self, encoder, decoder):
+        super().__init__(encoder, decoder)
+
+    @staticmethod
+    def add_args():
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--dropout', type=float, default=0.1)
+        parser.add_argument('--d_model', type=int, default=512)
+        parser.add_argument('--num_layers', type=int, default=6)
+        parser.add_argument('--n_head', type=int, default=8)
+        parser.add_argument('--dim_ff', type=int, default=2048)
+        parser.add_argument('--bind_input_output_embed', action='store_true', default=True)
+        parser.add_argument('--share_embedding', action='store_true', default=True)
+
+        args = parser.parse_args()
+
+        return args
+
+    @classmethod
+    def build_model(cls, args, src_vocab, tgt_vocab):
+        d_model = args.d_model
+        args.max_positions = getattr(args, 'max_positions', 1024)  # longest sequence length that is handled
+
+        # build the embeddings
+        src_embed = build_embedding(src_vocab, d_model)
+        if args.share_embedding:
+            assert src_vocab == tgt_vocab, "share_embedding requires a joined vocab"
+            tgt_embed = src_embed
+        else:
+            tgt_embed = build_embedding(tgt_vocab, d_model)
+
+        if args.bind_input_output_embed:
+            output_embed = nn.Parameter(tgt_embed.embedding.weight)
+        else:
+            output_embed = nn.Parameter(torch.Tensor(len(tgt_vocab), d_model), requires_grad=True)
+            nn.init.normal_(output_embed, mean=0, std=d_model ** -0.5)
+
+        pos_embed = nn.Embedding.from_pretrained(
+            get_sinusoid_encoding_table(args.max_positions + 1, d_model, padding_idx=0),
+            freeze=True)  # index 0 is reserved for padding here
+
+        encoder = TransformerSeq2SeqEncoder(vocab=src_vocab, embed=src_embed, pos_embed=pos_embed,
+                                            num_layers=args.num_layers, d_model=args.d_model,
+                                            n_head=args.n_head, dim_ff=args.dim_ff, dropout=args.dropout)
+        decoder = TransformerSeq2SeqDecoder(vocab=tgt_vocab, embed=tgt_embed, pos_embed=pos_embed,
+                                            num_layers=args.num_layers, d_model=args.d_model,
+                                            n_head=args.n_head, dim_ff=args.dim_ff, dropout=args.dropout,
+                                            output_embed=output_embed)
+
+        return TransformerSeq2SeqModel(encoder, decoder)
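A minimal usage sketch of the builder API added above, mirroring the updated test case at the end of this commit. The toy vocabulary, the shared src/tgt vocab, and relying on `add_args()` defaults are assumptions for illustration, not part of the diff:

```python
# Sketch only: drives TransformerSeq2SeqModel the way the updated unit test does.
import torch
from fastNLP import Vocabulary
from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel

vocab = Vocabulary()
vocab.add_word_lst("Another test !".split())       # toy vocabulary for illustration

args = TransformerSeq2SeqModel.add_args()           # argparse defaults when run without CLI flags
model = TransformerSeq2SeqModel.build_model(args, vocab, vocab)  # share_embedding=True, so one vocab

src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])          # batch x src_len
src_seq_len = torch.LongTensor([3, 2])
tgt_prev_words = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])   # batch x tgt_len

out = model(src_words_idx, src_seq_len, tgt_prev_words)
print(out['tgt_output'].size())   # expected: batch x tgt_len x len(vocab)
```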
@@ -51,15 +51,17 @@ __all__ = [
 
     'summary',
 
-    "BiLSTMEncoder",
     "TransformerSeq2SeqEncoder",
+    "LSTMSeq2SeqEncoder",
+    "Seq2SeqEncoder",
 
     "SequenceGenerator",
-    "LSTMDecoder",
-    "LSTMPast",
     "TransformerSeq2SeqDecoder",
+    "LSTMSeq2SeqDecoder",
+    "Seq2SeqDecoder",
 
     "TransformerPast",
-    "Decoder",
+    "LSTMPast",
     "Past"
 
 ]
@@ -9,13 +9,15 @@ __all__ = [
     "allowed_transitions",
 
     "SequenceGenerator",
-    "LSTMDecoder",
     "LSTMPast",
-    "TransformerSeq2SeqDecoder",
     "TransformerPast",
-    "Decoder",
     "Past",
 
+    "TransformerSeq2SeqDecoder",
+    "LSTMSeq2SeqDecoder",
+    "Seq2SeqDecoder"
 
 ]
 
 from .crf import ConditionalRandomField
@@ -23,4 +25,5 @@ from .crf import allowed_transitions
 from .mlp import MLP
 from .utils import viterbi_decode
 from .seq2seq_generator import SequenceGenerator
-from .seq2seq_decoder import *
+from .seq2seq_decoder import Seq2SeqDecoder, LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMPast, TransformerPast, \
+    Past
File diff suppressed because it is too large
@@ -2,23 +2,29 @@ __all__ = [
     "SequenceGenerator"
 ]
 import torch
-from .seq2seq_decoder import Decoder
+from ...models.seq2seq_model import BaseSeq2SeqModel
+from ..encoder.seq2seq_encoder import Seq2SeqEncoder
+from ..decoder.seq2seq_decoder import Seq2SeqDecoder
 import torch.nn.functional as F
 from ...core.utils import _get_model_device
 from functools import partial
+from ...core import Vocabulary
 
 
 class SequenceGenerator:
-    def __init__(self, decoder: Decoder, max_length=20, num_beams=1,
+    def __init__(self, encoder: Seq2SeqEncoder = None, decoder: Seq2SeqDecoder = None,
+                 max_length=20, num_beams=1,
                  do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None,
                  repetition_penalty=1, length_penalty=1.0):
         if do_sample:
-            self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
+            self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length,
+                                         num_beams=num_beams,
                                          temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id,
                                          eos_token_id=eos_token_id, repetition_penalty=repetition_penalty,
                                          length_penalty=length_penalty)
         else:
-            self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
+            self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length,
+                                         num_beams=num_beams,
                                          bos_token_id=bos_token_id, eos_token_id=eos_token_id,
                                          repetition_penalty=repetition_penalty,
                                          length_penalty=length_penalty)
@@ -32,30 +38,45 @@ class SequenceGenerator:
         self.eos_token_id = eos_token_id
         self.repetition_penalty = repetition_penalty
         self.length_penalty = length_penalty
+        # self.vocab = tgt_vocab
+        self.encoder = encoder
         self.decoder = decoder
 
     @torch.no_grad()
-    def generate(self, tokens=None, past=None):
+    def generate(self, src_tokens: torch.Tensor = None, src_seq_len: torch.Tensor = None, prev_tokens=None):
         """
 
-        :param torch.LongTensor tokens: batch_size x length, the tokens to start from
-        :param past:
+        :param src_tokens:
+        :param src_seq_len:
+        :param prev_tokens:
         :return:
         """
-        # TODO check whether decoding still works when the length of tokens is not 1
-        return self.generate_func(tokens=tokens, past=past)
+        if self.encoder is not None:
+            encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len)
+        else:
+            encoder_output = encoder_mask = None
+
+        # re-initialize past on every call
+        if encoder_output is not None:
+            self.decoder.init_past(encoder_output, encoder_mask)
+        else:
+            self.decoder.init_past()
+        return self.generate_func(src_tokens, src_seq_len, prev_tokens)
 
 
 @torch.no_grad()
-def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
+def greedy_generate(decoder: Seq2SeqDecoder, encoder_output=None, encoder_mask=None,
+                    prev_tokens=None, max_length=20, num_beams=1,
                     bos_token_id=None, eos_token_id=None,
                     repetition_penalty=1, length_penalty=1.0):
     """
     Greedily search for sentences
 
-    :param Decoder decoder: the Decoder object
-    :param torch.LongTensor tokens: batch_size x len, decoder input; if None, generation starts from bos_token_id
-    :param Past past: should hold the relevant encoder outputs.
+    :param decoder:
+    :param encoder_output:
+    :param encoder_mask:
+    :param prev_tokens: batch_size x len, decoder input; if None, generation starts from bos_token_id
     :param int max_length: maximum length of the generated sentences.
     :param int num_beams: beam size used for decoding.
     :param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
@@ -65,11 +86,18 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
     :return:
     """
     if num_beams == 1:
-        token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=1, top_k=50, top_p=1,
+        token_ids = _no_beam_search_generate(decoder=decoder,
+                                             encoder_output=encoder_output, encoder_mask=encoder_mask,
+                                             prev_tokens=prev_tokens,
+                                             max_length=max_length, temperature=1,
+                                             top_k=50, top_p=1,
                                              bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
                                              repetition_penalty=repetition_penalty, length_penalty=length_penalty)
     else:
-        token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams,
+        token_ids = _beam_search_generate(decoder=decoder,
+                                          encoder_output=encoder_output, encoder_mask=encoder_mask,
+                                          prev_tokens=prev_tokens, max_length=max_length,
+                                          num_beams=num_beams,
                                           temperature=1, top_k=50, top_p=1,
                                           bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
                                           repetition_penalty=repetition_penalty, length_penalty=length_penalty)
@@ -78,14 +106,17 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
 
 
 @torch.no_grad()
-def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, temperature=1.0, top_k=50,
+def sample_generate(decoder: Seq2SeqDecoder, encoder_output=None, encoder_mask=None,
+                    prev_tokens=None, max_length=20, num_beams=1,
+                    temperature=1.0, top_k=50,
                     top_p=1.0, bos_token_id=None, eos_token_id=None, repetition_penalty=1.0, length_penalty=1.0):
     """
     Generate sentences by sampling
 
-    :param Decoder decoder: the Decoder object
-    :param torch.LongTensor tokens: batch_size x len, decoder input; if None, generation starts from bos_token_id
-    :param Past past: should hold the relevant encoder outputs.
+    :param decoder
+    :param encoder_output:
+    :param encoder_mask:
+    :param torch.LongTensor prev_tokens: batch_size x len, decoder input; if None, generation starts from bos_token_id
     :param int max_length: maximum length of the generated sentences.
     :param int num_beam: beam size used for decoding.
     :param float temperature: temperature used when sampling
@@ -99,50 +130,55 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
     """
     # each position is generated by sampling
     if num_beams == 1:
-        token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=temperature,
+        token_ids = _no_beam_search_generate(decoder=decoder, encoder_output=encoder_output, encoder_mask=encoder_mask,
+                                             prev_tokens=prev_tokens, max_length=max_length,
+                                             temperature=temperature,
                                              top_k=top_k, top_p=top_p,
                                              bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
                                              repetition_penalty=repetition_penalty, length_penalty=length_penalty)
     else:
-        token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams,
+        token_ids = _beam_search_generate(decoder=decoder, encoder_output=encoder_output, encoder_mask=encoder_mask,
+                                          prev_tokens=prev_tokens, max_length=max_length,
+                                          num_beams=num_beams,
                                           temperature=temperature, top_k=top_k, top_p=top_p,
                                           bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
                                           repetition_penalty=repetition_penalty, length_penalty=length_penalty)
     return token_ids
 
 
-def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, temperature=1.0, top_k=50,
-                             top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
+def _no_beam_search_generate(decoder: Seq2SeqDecoder,
+                             encoder_output=None, encoder_mask: torch.Tensor = None,
+                             prev_tokens: torch.Tensor = None, max_length=20,
+                             temperature=1.0, top_k=50,
+                             top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=False,
                              repetition_penalty=1.0, length_penalty=1.0):
+    if encoder_output is not None:
+        batch_size = encoder_output.size(0)
+    else:
+        assert prev_tokens is not None, "You have to specify either `src_tokens` or `prev_tokens`"
+        batch_size = prev_tokens.size(0)
     device = _get_model_device(decoder)
-    if tokens is None:
+
+    if prev_tokens is None:
         if bos_token_id is None:
-            raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
-        if past is None:
-            raise RuntimeError("You have to specify either `past` or `tokens`.")
-        batch_size = past.num_samples()
-        if batch_size is None:
-            raise RuntimeError("Cannot infer the number of samples from `past`.")
-        tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
-    batch_size = tokens.size(0)
-    if past is not None:
-        assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."
+            raise RuntimeError("You have to specify either `prev_tokens` or `bos_token_id`.")
+        prev_tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
 
     if eos_token_id is None:
         _eos_token_id = float('nan')
     else:
         _eos_token_id = eos_token_id
 
-    for i in range(tokens.size(1)):
-        scores, past = decoder.decode(tokens[:, :i + 1], past)  # batch_size x vocab_size, Past
+    for i in range(prev_tokens.size(1)):  # run over prev_tokens once first, to initialize the decoder state
+        decoder.decode(prev_tokens[:, :i + 1], encoder_output, encoder_mask)
 
-    token_ids = tokens.clone()
+    token_ids = prev_tokens.clone()  # keeps all generated tokens
     cur_len = token_ids.size(1)
     dones = token_ids.new_zeros(batch_size).eq(1)
-    # tokens = tokens[:, -1:]
 
     while cur_len < max_length:
-        scores, past = decoder.decode(tokens, past)  # batch_size x vocab_size, Past
+        scores = decoder.decode(token_ids, encoder_output, encoder_mask)  # batch_size x vocab_size
 
         if repetition_penalty != 1.0:
             token_scores = scores.gather(dim=1, index=token_ids)
@@ -171,9 +207,9 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
             next_tokens = torch.argmax(scores, dim=-1)  # batch_size
 
         next_tokens = next_tokens.masked_fill(dones, 0)  # pad the samples that have already finished
-        tokens = next_tokens.unsqueeze(1)
+        next_tokens = next_tokens.unsqueeze(1)
 
-        token_ids = torch.cat([token_ids, tokens], dim=-1)  # batch_size x max_len
+        token_ids = torch.cat([token_ids, next_tokens], dim=-1)  # batch_size x max_len
 
         end_mask = next_tokens.eq(_eos_token_id)
         dones = dones.__or__(end_mask)
@@ -189,29 +225,31 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
     return token_ids
 
 
-def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, num_beams=4, temperature=1.0,
+def _beam_search_generate(decoder: Seq2SeqDecoder,
+                          encoder_output=None, encoder_mask: torch.Tensor = None,
+                          prev_tokens: torch.Tensor = None, max_length=20, num_beams=4, temperature=1.0,
                           top_k=50,
-                          top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
+                          top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=False,
                           repetition_penalty=1.0, length_penalty=None) -> torch.LongTensor:
     # run beam search
-    device = _get_model_device(decoder)
-    if tokens is None:
-        if bos_token_id is None:
-            raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
-        if past is None:
-            raise RuntimeError("You have to specify either `past` or `tokens`.")
-        batch_size = past.num_samples()
-        if batch_size is None:
-            raise RuntimeError("Cannot infer the number of samples from `past`.")
-        tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
-    batch_size = tokens.size(0)
-    if past is not None:
-        assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."
-
-    for i in range(tokens.size(1) - 1):  # if the input is longer, decode it first
-        scores, past = decoder.decode(tokens[:, :i + 1],
-                                      past)  # (batch_size, vocab_size), Past
-    scores, past = decoder.decode(tokens, past)  # the full sentence has to be passed in here
+    if encoder_output is not None:
+        batch_size = encoder_output.size(0)
+    else:
+        assert prev_tokens is not None, "You have to specify either `src_tokens` or `prev_tokens`"
+        batch_size = prev_tokens.size(0)
+
+    device = _get_model_device(decoder)
+
+    if prev_tokens is None:
+        if bos_token_id is None:
+            raise RuntimeError("You have to specify either `prev_tokens` or `bos_token_id`.")
+
+        prev_tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
+
+    for i in range(prev_tokens.size(1)):  # if the input is longer, decode it first
+        scores = decoder.decode(prev_tokens[:, :i + 1], encoder_output, encoder_mask)
 
     vocab_size = scores.size(1)
     assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."
 
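The beam initialization above duplicates every sample once per beam via `repeat_interleave` before calling `decoder.reorder_past` and `index_select`; a tiny stand-alone illustration (the sizes here are made up):

```python
import torch

batch_size, num_beams = 3, 2
indices = torch.arange(batch_size).repeat_interleave(num_beams)
print(indices)   # tensor([0, 0, 1, 1, 2, 2])
# index_select / reorder_past with these indices turns any per-sample state of
# shape (batch_size, ...) into (batch_size * num_beams, ...), one copy per beam.
```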
@@ -225,15 +263,15 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
     # obtain (batch_size, num_beams), (batch_size, num_beams)
     next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
 
+    # reorder according to the indices
     indices = torch.arange(batch_size, dtype=torch.long).to(device)
     indices = indices.repeat_interleave(num_beams)
-    decoder.reorder_past(indices, past)
+    decoder.reorder_past(indices)
+    prev_tokens = prev_tokens.index_select(dim=0, index=indices)  # batch_size * num_beams x length
 
-    tokens = tokens.index_select(dim=0, index=indices)  # batch_size * num_beams x length
     # record the generated tokens (batch_size', cur_len)
-    token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
+    token_ids = torch.cat([prev_tokens, next_tokens.view(-1, 1)], dim=-1)
     dones = [False] * batch_size
-    tokens = next_tokens.view(-1, 1)
 
     beam_scores = next_scores.view(-1)  # batch_size * num_beams
 
@@ -247,7 +285,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
     batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
 
     while cur_len < max_length:
-        scores, past = decoder.decode(tokens, past)  # batch_size * num_beams x vocab_size, Past
+        scores = decoder.decode(token_ids, encoder_output, encoder_mask)  # batch_size * num_beams x vocab_size
 
         if repetition_penalty != 1.0:
             token_scores = scores.gather(dim=1, index=token_ids)
@@ -300,9 +338,9 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
         _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
         beam_scores = _next_scores.view(-1)
 
-        # update the past state and regroup token_ids
+        # regroup the past/encoder state and token_ids
         reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to one dimension
-        decoder.reorder_past(reorder_inds, past)
+        decoder.reorder_past(reorder_inds)
 
         flag = True
         if cur_len + 1 == max_length:
@@ -327,8 +365,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
                     hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
 
         # reorganize the state of token_ids
-        tokens = _next_tokens
-        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
+        cur_tokens = _next_tokens
+        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), cur_tokens], dim=-1)
 
         for batch_idx in range(batch_size):
             dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
@@ -436,38 +474,3 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")
         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
     logits[indices_to_remove] = filter_value
     return logits
-
-
-if __name__ == '__main__':
-    # TODO check whether greedy_generate and sample_generate work correctly.
-    from torch import nn
-
-
-    class DummyDecoder(nn.Module):
-        def __init__(self, num_words):
-            super().__init__()
-            self.num_words = num_words
-
-        def decode(self, tokens, past):
-            batch_size = tokens.size(0)
-            return torch.randn(batch_size, self.num_words), past
-
-        def reorder_past(self, indices, past):
-            return past
-
-
-    num_words = 10
-    batch_size = 3
-    decoder = DummyDecoder(num_words)
-
-    tokens = greedy_generate(decoder=decoder, tokens=torch.zeros(batch_size, 1).long(), past=None, max_length=20,
-                             num_beams=2,
-                             bos_token_id=0, eos_token_id=num_words - 1,
-                             repetition_penalty=1, length_penalty=1.0)
-    print(tokens)
-
-    tokens = sample_generate(decoder, tokens=torch.zeros(batch_size, 1).long(),
-                             past=None, max_length=20, num_beams=2, temperature=1.0, top_k=50,
-                             top_p=1.0, bos_token_id=0, eos_token_id=num_words - 1, repetition_penalty=1.0,
-                             length_penalty=1.0)
-    print(tokens)
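The old `__main__` smoke test above is removed by this commit; a hedged sketch of how the reworked `SequenceGenerator` appears intended to be driven follows. The encoder/decoder come from a model built as in `seq2seq_model.py`, and the bos/eos ids are placeholders, not values fixed by this commit:

```python
# Sketch only, not part of the commit: assumes `model` was built via build_model() earlier.
import torch
from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator

generator = SequenceGenerator(encoder=model.encoder, decoder=model.decoder,
                              max_length=20, num_beams=4, do_sample=False,
                              bos_token_id=1, eos_token_id=2)   # placeholder special-token ids

src_tokens = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
src_seq_len = torch.LongTensor([3, 2])

token_ids = generator.generate(src_tokens=src_tokens, src_seq_len=src_seq_len)
print(token_ids.size())   # batch x generated_len, ending at eos or max_length
```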
@@ -31,8 +31,9 @@ __all__ = [
     "BiAttention",
     "SelfAttention",
 
-    "BiLSTMEncoder",
-    "TransformerSeq2SeqEncoder"
+    "LSTMSeq2SeqEncoder",
+    "TransformerSeq2SeqEncoder",
+    "Seq2SeqEncoder"
 ]
 
 from .attention import MultiHeadAttention, BiAttention, SelfAttention
@@ -45,4 +46,4 @@ from .star_transformer import StarTransformer
 from .transformer import TransformerEncoder
 from .variational_rnn import VarRNN, VarLSTM, VarGRU
 
-from .seq2seq_encoder import BiLSTMEncoder, TransformerSeq2SeqEncoder
+from .seq2seq_encoder import LSTMSeq2SeqEncoder, TransformerSeq2SeqEncoder, Seq2SeqEncoder
@@ -1,48 +1,238 @@
-__all__ = [
-    "TransformerSeq2SeqEncoder",
-    "BiLSTMEncoder"
-]
-
-from torch import nn
+import torch.nn as nn
 import torch
-from ...modules.encoder import LSTM
-from ...core.utils import seq_len_to_mask
-from torch.nn import TransformerEncoder
+from torch.nn import LayerNorm
+import torch.nn.functional as F
 from typing import Union, Tuple
-import numpy as np
+from ...core.utils import seq_len_to_mask
+import math
+from ...core import Vocabulary
+from ...modules import LSTM
 
 
-class TransformerSeq2SeqEncoder(nn.Module):
-    def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers: int = 6,
-                 d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1):
-        super(TransformerSeq2SeqEncoder, self).__init__()
-        self.embed = embed
-        self.transformer = TransformerEncoder(nn.TransformerEncoderLayer(d_model, n_head, dim_ff, dropout), num_layers)
-
-    def forward(self, words, seq_len):
-        """
-
-        :param words: batch, seq_len
-        :param seq_len:
-        :return: output: (batch, seq_len, dim) ; encoder_mask
-        """
-        words = self.embed(words)  # batch, seq_len, dim
-        words = words.transpose(0, 1)
-        encoder_mask = seq_len_to_mask(seq_len)  # batch, seq
-        words = self.transformer(words, src_key_padding_mask=~encoder_mask)  # seq_len,batch,dim
-
-        return words.transpose(0, 1), encoder_mask
+class MultiheadAttention(nn.Module):  # todo: decide where this should live
+    def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None):
+        super(MultiheadAttention, self).__init__()
+        self.d_model = d_model
+        self.n_head = n_head
+        self.dropout = dropout
+        self.head_dim = d_model // n_head
+        self.layer_idx = layer_idx
+        assert d_model % n_head == 0, "d_model should be divisible by n_head"
+        self.scaling = self.head_dim ** -0.5
+
+        self.q_proj = nn.Linear(d_model, d_model)
+        self.k_proj = nn.Linear(d_model, d_model)
+        self.v_proj = nn.Linear(d_model, d_model)
+        self.out_proj = nn.Linear(d_model, d_model)
+
+        self.reset_parameters()
+
+    def forward(self, query, key, value, key_mask=None, attn_mask=None, past=None):
+        """
+
+        :param query: batch x seq x dim
+        :param key:
+        :param value:
+        :param key_mask: batch x seq, marks which keys should not be attended to; positions where the mask is 1 are attended to
+        :param attn_mask: seq x seq, masks out parts of the attention map; mainly for decoder self-attention during training, lower triangle set to 1
+        :param past: cached state used at inference time, e.g. the encoder output and the decoder's previous k/v, to save computation.
+        :return:
+        """
+        assert key.size() == value.size()
+        if past is not None:
+            assert self.layer_idx is not None
+        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
+
+        q = self.q_proj(query)  # batch x seq x dim
+        q *= self.scaling
+        k = v = None
+        prev_k = prev_v = None
+
+        # fetch k/v from past
+        if past is not None:  # we are in the inference stage
+            if qkv_same:  # decoder self-attention
+                prev_k = past.decoder_prev_key[self.layer_idx]
+                prev_v = past.decoder_prev_value[self.layer_idx]
+            else:  # decoder-encoder attention: simply load the cached keys
+                k = past.encoder_key[self.layer_idx]
+                v = past.encoder_value[self.layer_idx]
+
+        if k is None:
+            k = self.k_proj(key)
+            v = self.v_proj(value)
+
+        if prev_k is not None:
+            k = torch.cat((prev_k, k), dim=1)
+            v = torch.cat((prev_v, v), dim=1)
+
+        # update past
+        if past is not None:
+            if qkv_same:
+                past.decoder_prev_key[self.layer_idx] = k
+                past.decoder_prev_value[self.layer_idx] = v
+            else:
+                past.encoder_key[self.layer_idx] = k
+                past.encoder_value[self.layer_idx] = v
+
+        # compute the attention
+        batch_size, q_len, d_model = query.size()
+        k_len, v_len = k.size(1), v.size(1)
+        q = q.contiguous().view(batch_size, q_len, self.n_head, self.head_dim)
+        k = k.contiguous().view(batch_size, k_len, self.n_head, self.head_dim)
+        v = v.contiguous().view(batch_size, v_len, self.n_head, self.head_dim)
+
+        attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k)  # bs,q_len,k_len,n_head
+        if key_mask is not None:
+            _key_mask = ~key_mask[:, None, :, None].bool()  # batch,1,k_len,n_head
+            attn_weights = attn_weights.masked_fill(_key_mask, -float('inf'))
+
+        if attn_mask is not None:
+            _attn_mask = ~attn_mask[None, :, :, None].bool()  # 1,q_len,k_len,n_head
+            attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf'))
+
+        attn_weights = F.softmax(attn_weights, dim=2)
+        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v)  # batch,q_len,n_head,head_dim
+        output = output.reshape(batch_size, q_len, -1)
+        output = self.out_proj(output)  # batch,q_len,dim
+
+        return output, attn_weights
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.q_proj.weight)
+        nn.init.xavier_uniform_(self.k_proj.weight)
+        nn.init.xavier_uniform_(self.v_proj.weight)
+        nn.init.xavier_uniform_(self.out_proj.weight)
+
+    def set_layer_idx(self, layer_idx):
+        self.layer_idx = layer_idx
+
+
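For reference, the `MultiheadAttention` added above computes the ordinary scaled dot-product attention per head; the 1/sqrt(d_head) factor is folded into `q *= self.scaling` before the einsum:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_{head}}}\right)V,
\qquad d_{head} = d_{model} / n_{head}
```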
-class BiLSTMEncoder(nn.Module):
-    def __init__(self, embed, num_layers=3, hidden_size=400, dropout=0.3):
+class TransformerSeq2SeqEncoderLayer(nn.Module):
+    def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048,
+                 dropout: float = 0.1):
+        super(TransformerSeq2SeqEncoderLayer, self).__init__()
+        self.d_model = d_model
+        self.n_head = n_head
+        self.dim_ff = dim_ff
+        self.dropout = dropout
+
+        self.self_attn = MultiheadAttention(d_model, n_head, dropout)
+        self.attn_layer_norm = LayerNorm(d_model)
+        self.ffn_layer_norm = LayerNorm(d_model)
+
+        self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff),
+                                 nn.ReLU(),
+                                 nn.Dropout(dropout),
+                                 nn.Linear(self.dim_ff, self.d_model),
+                                 nn.Dropout(dropout))
+
+    def forward(self, x, encoder_mask):
+        """
+
+        :param x: batch,src_seq,dim
+        :param encoder_mask: batch,src_seq
+        :return:
+        """
+        # attention
+        residual = x
+        x = self.attn_layer_norm(x)
+        x, _ = self.self_attn(query=x,
+                              key=x,
+                              value=x,
+                              key_mask=encoder_mask)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = residual + x
+
+        # ffn
+        residual = x
+        x = self.ffn_layer_norm(x)
+        x = self.ffn(x)
+        x = residual + x
+
+        return x
+
+
+class Seq2SeqEncoder(nn.Module):
+    def __init__(self, vocab):
         super().__init__()
+        self.vocab = vocab
+
+    def forward(self, src_words, src_seq_len):
+        raise NotImplementedError
+
+
+class TransformerSeq2SeqEncoder(Seq2SeqEncoder):
+    def __init__(self, vocab: Vocabulary, embed: nn.Module, pos_embed: nn.Module = None, num_layers: int = 6,
+                 d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1):
+        super(TransformerSeq2SeqEncoder, self).__init__(vocab)
         self.embed = embed
-        self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size // 2, bidirectional=True,
+        self.embed_scale = math.sqrt(d_model)
+        self.pos_embed = pos_embed
+        self.num_layers = num_layers
+        self.d_model = d_model
+        self.n_head = n_head
+        self.dim_ff = dim_ff
+        self.dropout = dropout
+
+        self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout)
+                                           for _ in range(num_layers)])
+        self.layer_norm = LayerNorm(d_model)
+
+    def forward(self, src_words, src_seq_len):
+        """
+
+        :param src_words: batch, src_seq_len
+        :param src_seq_len: [batch]
+        :return:
+        """
+        batch_size, max_src_len = src_words.size()
+        device = src_words.device
+        x = self.embed(src_words) * self.embed_scale  # batch, seq, dim
+        if self.pos_embed is not None:
+            position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device)
+            x += self.pos_embed(position)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+
+        encoder_mask = seq_len_to_mask(src_seq_len)
+        encoder_mask = encoder_mask.to(device)
+
+        for layer in self.layer_stacks:
+            x = layer(x, encoder_mask)
+
+        x = self.layer_norm(x)
+
+        return x, encoder_mask
+
+
+class LSTMSeq2SeqEncoder(Seq2SeqEncoder):
+    def __init__(self, vocab: Vocabulary, embed: nn.Module, num_layers: int = 3, hidden_size: int = 400,
+                 dropout: float = 0.3, bidirectional=True):
+        super().__init__(vocab)
+        self.embed = embed
+        self.num_layers = num_layers
+        self.dropout = dropout
+        self.hidden_size = hidden_size
+        self.bidirectional = bidirectional
+        self.lstm = LSTM(input_size=embed.embedding_dim, hidden_size=hidden_size // 2, bidirectional=bidirectional,
                         batch_first=True, dropout=dropout, num_layers=num_layers)
 
-    def forward(self, words, seq_len):
-        words = self.embed(words)
-        words, hx = self.lstm(words, seq_len)
-
-        return words, hx
+    def forward(self, src_words, src_seq_len):
+        batch_size = src_words.size(0)
+        device = src_words.device
+        x = self.embed(src_words)
+        x, (final_hidden, final_cell) = self.lstm(x, src_seq_len)
+        encoder_mask = seq_len_to_mask(src_seq_len).to(device)
+
+        # x: batch,seq_len,dim; h/c: num_layers*2,batch,dim
+
+        def concat_bidir(input):
+            output = input.view(self.num_layers, 2, batch_size, -1).transpose(1, 2).contiguous()
+            return output.view(self.num_layers, batch_size, -1)
+
+        if self.bidirectional:
+            final_hidden = concat_bidir(final_hidden)  # concatenate the bidirectional hidden states as input for the decoder
+            final_cell = concat_bidir(final_cell)
+
+        return (x, (final_hidden, final_cell)), encoder_mask  # split into two return values to match the forward of Seq2SeqBaseModel
@@ -7,10 +7,12 @@ from transformer.Layers import EncoderLayer, DecoderLayer
 
 __author__ = "Yu-Hsiang Huang"
 
 
+
 def get_non_pad_mask(seq):
     assert seq.dim() == 2
     return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)
 
+
 def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
     ''' Sinusoid position encoding table '''
 
@@ -31,6 +33,7 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
 
     return torch.FloatTensor(sinusoid_table)
 
+
 def get_attn_key_pad_mask(seq_k, seq_q):
     ''' For masking out the padding part of key sequence. '''
 
@@ -41,6 +44,7 @@ def get_attn_key_pad_mask(seq_k, seq_q):
 
     return padding_mask
 
+
 def get_subsequent_mask(seq):
     ''' For masking out the subsequent info. '''
 
@@ -51,6 +55,7 @@ def get_subsequent_mask(seq):
 
     return subsequent_mask
 
+
 class Encoder(nn.Module):
     ''' A encoder model with self attention mechanism. '''
 
@@ -98,6 +103,7 @@ class Encoder(nn.Module):
             return enc_output, enc_slf_attn_list
         return enc_output,
 
+
 class Decoder(nn.Module):
     ''' A decoder model with self attention mechanism. '''
 
@@ -152,6 +158,7 @@ class Decoder(nn.Module):
             return dec_output, dec_slf_attn_list, dec_enc_attn_list
         return dec_output,
 
+
 class Transformer(nn.Module):
     ''' A sequence to sequence model with attention mechanism. '''
 
@@ -181,8 +188,8 @@ class Transformer(nn.Module):
         nn.init.xavier_normal_(self.tgt_word_prj.weight)
 
         assert d_model == d_word_vec, \
            'To facilitate the residual connections, \
             the dimensions of all module outputs shall be the same.'
 
         if tgt_emb_prj_weight_sharing:
             # Share the weight matrix between target word embedding & the final logit dense layer
@@ -194,7 +201,7 @@ class Transformer(nn.Module):
         if emb_src_tgt_weight_sharing:
             # Share the weight matrix between source & target word embeddings
             assert n_src_vocab == n_tgt_vocab, \
                 "To share word embedding table, the vocabulary size of src/tgt shall be the same."
             self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight
 
     def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):
@@ -2,8 +2,10 @@ import unittest
 
 import torch
 
-from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
-from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder
+from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
+from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, \
+    LSTMSeq2SeqDecoder
+from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel
 from fastNLP import Vocabulary
 from fastNLP.embeddings import StaticEmbedding
 from fastNLP.core.utils import seq_len_to_mask
@@ -15,22 +17,17 @@ class TestTransformerSeq2SeqDecoder(unittest.TestCase):
         vocab.add_word_lst("Another test !".split())
         embed = StaticEmbedding(vocab, embedding_dim=512)
 
-        encoder = TransformerSeq2SeqEncoder(embed)
-        decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True)
+        args = TransformerSeq2SeqModel.add_args()
+        model = TransformerSeq2SeqModel.build_model(args, vocab, vocab)
 
         src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
         tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
         src_seq_len = torch.LongTensor([3, 2])
 
-        encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
-        past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask)
-
-        decoder_outputs = decoder(tgt_words_idx, past)
-
-        print(decoder_outputs)
-        print(mask)
-
-        self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))
+        output = model(src_words_idx, src_seq_len, tgt_words_idx)
+        print(output)
+        # self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))
 
     def test_decode(self):
         pass  # todo