diff --git a/fastNLP/models/__init__.py b/fastNLP/models/__init__.py index c6930b9a..440a7bd2 100644 --- a/fastNLP/models/__init__.py +++ b/fastNLP/models/__init__.py @@ -9,18 +9,18 @@ fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models """ __all__ = [ "CNNText", - + "SeqLabeling", "AdvSeqLabel", "BiLSTMCRF", - + "ESIM", - + "StarTransEnc", "STSeqLabel", "STNLICls", "STSeqCls", - + "BiaffineParser", "GraphParser", @@ -30,7 +30,9 @@ __all__ = [ "BertForTokenClassification", "BertForQuestionAnswering", - "TransformerSeq2SeqModel" + "TransformerSeq2SeqModel", + "LSTMSeq2SeqModel", + "BaseSeq2SeqModel" ] from .base_model import BaseModel @@ -41,7 +43,8 @@ from .cnn_text_classification import CNNText from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF from .snli import ESIM from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel -from .seq2seq_model import TransformerSeq2SeqModel +from .seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel, BaseSeq2SeqModel import sys from ..doc_utils import doc_process -doc_process(sys.modules[__name__]) \ No newline at end of file + +doc_process(sys.modules[__name__]) diff --git a/fastNLP/models/seq2seq_model.py b/fastNLP/models/seq2seq_model.py index 94b96198..042f8544 100644 --- a/fastNLP/models/seq2seq_model.py +++ b/fastNLP/models/seq2seq_model.py @@ -1,26 +1,153 @@ -import torch.nn as nn import torch -from typing import Union, Tuple +from torch import nn import numpy as np -from fastNLP.modules import TransformerSeq2SeqDecoder, TransformerSeq2SeqEncoder, TransformerPast +from ..embeddings import StaticEmbedding +from ..modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, Seq2SeqEncoder, LSTMSeq2SeqEncoder +from ..modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder +from ..core import Vocabulary +import argparse -class TransformerSeq2SeqModel(nn.Module): # todo 参考fairseq的FairseqModel的写法 - def __init__(self, src_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], - tgt_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], - num_layers: int = 6, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, - output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None, - bind_input_output_embed=False): - super().__init__() - self.encoder = TransformerSeq2SeqEncoder(src_embed, num_layers, d_model, n_head, dim_ff, dropout) - self.decoder = TransformerSeq2SeqDecoder(tgt_embed, num_layers, d_model, n_head, dim_ff, dropout, output_embed, - bind_input_output_embed) +def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): + ''' Sinusoid position encoding table ''' - self.num_layers = num_layers + def cal_angle(position, hid_idx): + return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) - def forward(self, words, target, seq_len): - encoder_output, encoder_mask = self.encoder(words, seq_len) - past = TransformerPast(encoder_output, encoder_mask, self.num_layers) - outputs = self.decoder(target, past, return_attention=False) + def get_posi_angle_vec(position): + return [cal_angle(position, hid_j) for hid_j in range(d_hid)] - return outputs + sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) + + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + if padding_idx is not None: + # zero vector for padding dimension + sinusoid_table[padding_idx] = 0. + + return torch.FloatTensor(sinusoid_table) + + +def build_embedding(vocab, embed_dim, model_dir_or_name=None): + """ + todo: 根据需求可丰富该函数的功能,目前只返回StaticEmbedding + :param vocab: Vocabulary + :param embed_dim: + :param model_dir_or_name: + :return: + """ + assert isinstance(vocab, Vocabulary) + embed = StaticEmbedding(vocab=vocab, embedding_dim=embed_dim, model_dir_or_name=model_dir_or_name) + + return embed + + +class BaseSeq2SeqModel(nn.Module): + def __init__(self, encoder, decoder): + super(BaseSeq2SeqModel, self).__init__() + self.encoder = encoder + self.decoder = decoder + assert isinstance(self.encoder, Seq2SeqEncoder) + assert isinstance(self.decoder, Seq2SeqDecoder) + + def forward(self, src_words, src_seq_len, tgt_prev_words): + encoder_output, encoder_mask = self.encoder(src_words, src_seq_len) + decoder_output = self.decoder(tgt_prev_words, encoder_output, encoder_mask) + + return {'tgt_output': decoder_output} + + +class LSTMSeq2SeqModel(BaseSeq2SeqModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--dropout', type=float, default=0.3) + parser.add_argument('--embedding_dim', type=int, default=300) + parser.add_argument('--num_layers', type=int, default=3) + parser.add_argument('--hidden_size', type=int, default=300) + parser.add_argument('--bidirectional', action='store_true', default=True) + args = parser.parse_args() + + return args + + @classmethod + def build_model(cls, args, src_vocab, tgt_vocab): + # 处理embedding + src_embed = build_embedding(src_vocab, args.embedding_dim) + if args.share_embedding: + assert src_vocab == tgt_vocab, "share_embedding requires a joined vocab" + tgt_embed = src_embed + else: + tgt_embed = build_embedding(tgt_vocab, args.embedding_dim) + + if args.bind_input_output_embed: + output_embed = nn.Parameter(tgt_embed.embedding.weight) + else: + output_embed = nn.Parameter(torch.Tensor(len(tgt_vocab), args.embedding_dim), requires_grad=True) + nn.init.normal_(output_embed, mean=0, std=args.embedding_dim ** -0.5) + + encoder = LSTMSeq2SeqEncoder(vocab=src_vocab, embed=src_embed, num_layers=args.num_layers, + hidden_size=args.hidden_size, dropout=args.dropout, + bidirectional=args.bidirectional) + decoder = LSTMSeq2SeqDecoder(vocab=tgt_vocab, embed=tgt_embed, num_layers=args.num_layers, + hidden_size=args.hidden_size, dropout=args.dropout, output_embed=output_embed, + attention=True) + + return LSTMSeq2SeqModel(encoder, decoder) + + +class TransformerSeq2SeqModel(BaseSeq2SeqModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--dropout', type=float, default=0.1) + parser.add_argument('--d_model', type=int, default=512) + parser.add_argument('--num_layers', type=int, default=6) + parser.add_argument('--n_head', type=int, default=8) + parser.add_argument('--dim_ff', type=int, default=2048) + parser.add_argument('--bind_input_output_embed', action='store_true', default=True) + parser.add_argument('--share_embedding', action='store_true', default=True) + + args = parser.parse_args() + + return args + + @classmethod + def build_model(cls, args, src_vocab, tgt_vocab): + d_model = args.d_model + args.max_positions = getattr(args, 'max_positions', 1024) # 处理的最长长度 + + # 处理embedding + src_embed = build_embedding(src_vocab, d_model) + if args.share_embedding: + assert src_vocab == tgt_vocab, "share_embedding requires a joined vocab" + tgt_embed = src_embed + else: + tgt_embed = build_embedding(tgt_vocab, d_model) + + if args.bind_input_output_embed: + output_embed = nn.Parameter(tgt_embed.embedding.weight) + else: + output_embed = nn.Parameter(torch.Tensor(len(tgt_vocab), d_model), requires_grad=True) + nn.init.normal_(output_embed, mean=0, std=d_model ** -0.5) + + pos_embed = nn.Embedding.from_pretrained( + get_sinusoid_encoding_table(args.max_positions + 1, d_model, padding_idx=0), + freeze=True) # 这里规定0是padding + + encoder = TransformerSeq2SeqEncoder(vocab=src_vocab, embed=src_embed, pos_embed=pos_embed, + num_layers=args.num_layers, d_model=args.d_model, + n_head=args.n_head, dim_ff=args.dim_ff, dropout=args.dropout) + decoder = TransformerSeq2SeqDecoder(vocab=tgt_vocab, embed=tgt_embed, pos_embed=pos_embed, + num_layers=args.num_layers, d_model=args.d_model, + n_head=args.n_head, dim_ff=args.dim_ff, dropout=args.dropout, + output_embed=output_embed) + + return TransformerSeq2SeqModel(encoder, decoder) diff --git a/fastNLP/modules/__init__.py b/fastNLP/modules/__init__.py index e973379d..ffda1753 100644 --- a/fastNLP/modules/__init__.py +++ b/fastNLP/modules/__init__.py @@ -51,15 +51,17 @@ __all__ = [ 'summary', - "BiLSTMEncoder", "TransformerSeq2SeqEncoder", + "LSTMSeq2SeqEncoder", + "Seq2SeqEncoder", "SequenceGenerator", - "LSTMDecoder", - "LSTMPast", "TransformerSeq2SeqDecoder", + "LSTMSeq2SeqDecoder", + "Seq2SeqDecoder", + "TransformerPast", - "Decoder", + "LSTMPast", "Past" ] diff --git a/fastNLP/modules/decoder/__init__.py b/fastNLP/modules/decoder/__init__.py index e3bceff0..b4e1b623 100644 --- a/fastNLP/modules/decoder/__init__.py +++ b/fastNLP/modules/decoder/__init__.py @@ -9,13 +9,15 @@ __all__ = [ "allowed_transitions", "SequenceGenerator", - "LSTMDecoder", + "LSTMPast", - "TransformerSeq2SeqDecoder", "TransformerPast", - "Decoder", "Past", + "TransformerSeq2SeqDecoder", + "LSTMSeq2SeqDecoder", + "Seq2SeqDecoder" + ] from .crf import ConditionalRandomField @@ -23,4 +25,5 @@ from .crf import allowed_transitions from .mlp import MLP from .utils import viterbi_decode from .seq2seq_generator import SequenceGenerator -from .seq2seq_decoder import * +from .seq2seq_decoder import Seq2SeqDecoder, LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMPast, TransformerPast, \ + Past diff --git a/fastNLP/modules/decoder/seq2seq_decoder.py b/fastNLP/modules/decoder/seq2seq_decoder.py index 9a4e27de..c817e8fc 100644 --- a/fastNLP/modules/decoder/seq2seq_decoder.py +++ b/fastNLP/modules/decoder/seq2seq_decoder.py @@ -1,513 +1,14 @@ -# coding=utf-8 -__all__ = [ - "TransformerPast", - "LSTMPast", - "Past", - "LSTMDecoder", - "TransformerSeq2SeqDecoder", - "Decoder" -] +import torch.nn as nn import torch -from torch import nn -import abc -import torch.nn.functional as F -from ...embeddings import StaticEmbedding -import numpy as np -from typing import Union, Tuple -from ...embeddings.utils import get_embeddings from torch.nn import LayerNorm +from ..encoder.seq2seq_encoder import MultiheadAttention +import torch.nn.functional as F import math - - -# from reproduction.Summarization.Baseline.tools.PositionEmbedding import \ -# get_sinusoid_encoding_table # todo: 应该将position embedding移到core - -def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): - ''' Sinusoid position encoding table ''' - - def cal_angle(position, hid_idx): - return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) - - def get_posi_angle_vec(position): - return [cal_angle(position, hid_j) for hid_j in range(d_hid)] - - sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) - - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 - - if padding_idx is not None: - # zero vector for padding dimension - sinusoid_table[padding_idx] = 0. - - return torch.FloatTensor(sinusoid_table) - - -class Past: - def __init__(self): - pass - - @abc.abstractmethod - def num_samples(self): - pass - - -class TransformerPast(Past): - def __init__(self, encoder_outputs: torch.Tensor = None, encoder_mask: torch.Tensor = None, - num_decoder_layer: int = 6): - """ - - :param encoder_outputs: (batch,src_seq_len,dim) - :param encoder_mask: (batch,src_seq_len) - :param encoder_key: list of (batch, src_seq_len, dim) - :param encoder_value: - :param decoder_prev_key: - :param decoder_prev_value: - """ - self.encoder_outputs = encoder_outputs - self.encoder_mask = encoder_mask - self.encoder_key = [None] * num_decoder_layer - self.encoder_value = [None] * num_decoder_layer - self.decoder_prev_key = [None] * num_decoder_layer - self.decoder_prev_value = [None] * num_decoder_layer - - def num_samples(self): - if self.encoder_outputs is not None: - return self.encoder_outputs.size(0) - return None - - def _reorder_state(self, state, indices): - if type(state) == torch.Tensor: - state = state.index_select(index=indices, dim=0) - elif type(state) == list: - for i in range(len(state)): - assert state[i] is not None - state[i] = state[i].index_select(index=indices, dim=0) - else: - raise ValueError('State does not support other format') - - return state - - def reorder_past(self, indices: torch.LongTensor): - self.encoder_outputs = self._reorder_state(self.encoder_outputs, indices) - self.encoder_mask = self._reorder_state(self.encoder_mask, indices) - self.encoder_key = self._reorder_state(self.encoder_key, indices) - self.encoder_value = self._reorder_state(self.encoder_value, indices) - self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) - self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) - - -class Decoder(nn.Module): - def __init__(self): - super().__init__() - - def reorder_past(self, indices: torch.LongTensor, past: Past) -> Past: - """ - 根据indices中的index,将past的中状态置为正确的顺序 - - :param torch.LongTensor indices: - :param Past past: - :return: - """ - raise NotImplemented - - def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]: - """ - 当模型进行解码时,使用这个函数。只返回一个batch_size x vocab_size的结果。需要考虑一种特殊情况,即tokens长度不是1,即给定了 - 解码句子开头的情况,这种情况需要查看Past中是否正确计算了decode的状态 - - :return: - """ - raise NotImplemented - - -class DecoderMultiheadAttention(nn.Module): - """ - Transformer Decoder端的multihead layer - 相比原版的Multihead功能一致,但能够在inference时加速 - 参考fairseq - """ - - def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): - super(DecoderMultiheadAttention, self).__init__() - self.d_model = d_model - self.n_head = n_head - self.dropout = dropout - self.head_dim = d_model // n_head - self.layer_idx = layer_idx - assert d_model % n_head == 0, "d_model should be divisible by n_head" - self.scaling = self.head_dim ** -0.5 - - self.q_proj = nn.Linear(d_model, d_model) - self.k_proj = nn.Linear(d_model, d_model) - self.v_proj = nn.Linear(d_model, d_model) - self.out_proj = nn.Linear(d_model, d_model) - - self.reset_parameters() - - def forward(self, query, key, value, self_attn_mask=None, encoder_attn_mask=None, past=None, inference=False): - """ - - :param query: (batch, seq_len, dim) - :param key: (batch, seq_len, dim) - :param value: (batch, seq_len, dim) - :param self_attn_mask: None or ByteTensor (1, seq_len, seq_len) - :param encoder_attn_mask: (batch, src_len) ByteTensor - :param past: required for now - :param inference: - :return: x和attention weight - """ - if encoder_attn_mask is not None: - assert self_attn_mask is None - assert past is not None, "Past is required for now" - is_encoder_attn = True if encoder_attn_mask is not None else False - - q = self.q_proj(query) # (batch,q_len,dim) - q *= self.scaling - k = v = None - prev_k = prev_v = None - - if inference and is_encoder_attn and past.encoder_key[self.layer_idx] is not None: - k = past.encoder_key[self.layer_idx] # (batch,k_len,dim) - v = past.encoder_value[self.layer_idx] # (batch,v_len,dim) - else: - if inference and not is_encoder_attn and past.decoder_prev_key[self.layer_idx] is not None: - prev_k = past.decoder_prev_key[self.layer_idx] # (batch, seq_len, dim) - prev_v = past.decoder_prev_value[self.layer_idx] - - if k is None: - k = self.k_proj(key) - v = self.v_proj(value) - if prev_k is not None: - k = torch.cat((prev_k, k), dim=1) - v = torch.cat((prev_v, v), dim=1) - - # 更新past - if inference and is_encoder_attn and past.encoder_key[self.layer_idx] is None: - past.encoder_key[self.layer_idx] = k - past.encoder_value[self.layer_idx] = v - if inference and not is_encoder_attn: - past.decoder_prev_key[self.layer_idx] = prev_k if prev_k is not None else k - past.decoder_prev_value[self.layer_idx] = prev_v if prev_v is not None else v - - batch_size, q_len, d_model = query.size() - k_len, v_len = k.size(1), v.size(1) - q = q.contiguous().view(batch_size, q_len, self.n_head, self.head_dim) - k = k.contiguous().view(batch_size, k_len, self.n_head, self.head_dim) - v = v.contiguous().view(batch_size, v_len, self.n_head, self.head_dim) - - attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k) # bs,q_len,k_len,n_head - mask = encoder_attn_mask if is_encoder_attn else self_attn_mask - if mask is not None: - if len(mask.size()) == 2: # 是encoder mask, batch,src_len/k_len - mask = mask[:, None, :, None] - else: # (1, seq_len, seq_len) - mask = mask[..., None] - _mask = ~mask.bool() - - attn_weights = attn_weights.masked_fill(_mask, float('-inf')) - - attn_weights = F.softmax(attn_weights, dim=2) - attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) - - output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim - output = output.reshape(batch_size, q_len, -1) - output = self.out_proj(output) # batch,q_len,dim - - return output, attn_weights - - def reset_parameters(self): - nn.init.xavier_uniform_(self.q_proj.weight) - nn.init.xavier_uniform_(self.k_proj.weight) - nn.init.xavier_uniform_(self.v_proj.weight) - nn.init.xavier_uniform_(self.out_proj.weight) - - -class TransformerSeq2SeqDecoderLayer(nn.Module): - def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, - layer_idx: int = None): - super(TransformerSeq2SeqDecoderLayer, self).__init__() - self.d_model = d_model - self.n_head = n_head - self.dim_ff = dim_ff - self.dropout = dropout - self.layer_idx = layer_idx # 记录layer的层索引,以方便获取past的信息 - - self.self_attn = DecoderMultiheadAttention(d_model, n_head, dropout, layer_idx) - self.self_attn_layer_norm = LayerNorm(d_model) - - self.encoder_attn = DecoderMultiheadAttention(d_model, n_head, dropout, layer_idx) - self.encoder_attn_layer_norm = LayerNorm(d_model) - - self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), - nn.ReLU(), - nn.Dropout(dropout), - nn.Linear(self.dim_ff, self.d_model), - nn.Dropout(dropout)) - - self.final_layer_norm = LayerNorm(self.d_model) - - def forward(self, x, encoder_outputs, self_attn_mask=None, encoder_attn_mask=None, past=None, inference=False): - """ - - :param x: (batch, seq_len, dim) - :param encoder_outputs: (batch,src_seq_len,dim) - :param self_attn_mask: - :param encoder_attn_mask: - :param past: - :param inference: - :return: - """ - if inference: - assert past is not None, "Past is required when inference" - - # self attention part - residual = x - x = self.self_attn_layer_norm(x) - x, _ = self.self_attn(query=x, - key=x, - value=x, - self_attn_mask=self_attn_mask, - past=past, - inference=inference) - x = F.dropout(x, p=self.dropout, training=self.training) - x = residual + x - - # encoder attention part - residual = x - x = self.encoder_attn_layer_norm(x) - x, attn_weight = self.encoder_attn(query=x, - key=past.encoder_outputs, - value=past.encoder_outputs, - encoder_attn_mask=past.encoder_mask, - past=past, - inference=inference) - x = F.dropout(x, p=self.dropout, training=self.training) - x = residual + x - - # ffn - residual = x - x = self.final_layer_norm(x) - x = self.ffn(x) - x = residual + x - - return x, attn_weight - - -class TransformerSeq2SeqDecoder(Decoder): - def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers: int = 6, - d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, - output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None, - bind_input_output_embed=False): - """ - - :param embed: decoder端输入的embedding - :param num_layers: Transformer Decoder层数 - :param d_model: Transformer参数 - :param n_head: Transformer参数 - :param dim_ff: Transformer参数 - :param dropout: - :param output_embed: 输出embedding - :param bind_input_output_embed: 是否共享输入输出的embedding权重 - """ - super(TransformerSeq2SeqDecoder, self).__init__() - self.token_embed = get_embeddings(embed) - self.dropout = dropout - - self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx) - for layer_idx in range(num_layers)]) - - if isinstance(output_embed, int): - output_embed = (output_embed, d_model) - output_embed = get_embeddings(output_embed) - elif output_embed is not None: - assert not bind_input_output_embed, "When `output_embed` is not None, " \ - "`bind_input_output_embed` must be False." - if isinstance(output_embed, StaticEmbedding): - for i in self.token_embed.words_to_words: - assert i == self.token_embed.words_to_words[i], "The index does not match." - output_embed = self.token_embed.embedding.weight - else: - output_embed = get_embeddings(output_embed) - else: - if not bind_input_output_embed: - raise RuntimeError("You have to specify output embedding.") - - # todo: 由于每个模型都有embedding的绑定或其他操作,建议挪到外部函数以减少冗余,可参考fairseq - self.pos_embed = nn.Embedding.from_pretrained( - get_sinusoid_encoding_table(n_position=1024, d_hid=d_model, padding_idx=0), - freeze=True - ) - - if bind_input_output_embed: - assert output_embed is None, "When `bind_input_output_embed=True`, `output_embed` must be None" - if isinstance(self.token_embed, StaticEmbedding): - for i in self.token_embed.words_to_words: - assert i == self.token_embed.words_to_words[i], "The index does not match." - self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1), requires_grad=True) - else: - if isinstance(output_embed, nn.Embedding): - self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1), requires_grad=True) - else: - self.output_embed = output_embed.transpose(0, 1) - self.output_hidden_size = self.output_embed.size(0) - - self.embed_scale = math.sqrt(d_model) - - def forward(self, tokens, past, return_attention=False, inference=False): - """ - - :param tokens: torch.LongTensor, tokens: batch_size , decode_len - :param self_attn_mask: 在inference的时候不需要,而在train的时候,因为训练的时候交叉熵会自动屏蔽掉padding的地方,所以也不需要 - :param past: TransformerPast: 包含encoder输出及mask,在inference阶段保存了上一时刻的key和value以减少矩阵运算 - :param return_attention: - :param inference: 是否在inference阶段 - :return: - """ - assert past is not None - batch_size, decode_len = tokens.size() - device = tokens.device - pos_idx = torch.arange(1, decode_len + 1).unsqueeze(0).long() - - if not inference: - self_attn_mask = self._get_triangle_mask(decode_len) - self_attn_mask = self_attn_mask.to(device)[None, :, :] # 1,seq,seq - else: - self_attn_mask = None - - tokens = self.token_embed(tokens) * self.embed_scale # bs,decode_len,embed_dim - pos = self.pos_embed(pos_idx) # 1,decode_len,embed_dim - tokens = pos + tokens - if inference: - tokens = tokens[:, -1:, :] - - x = F.dropout(tokens, p=self.dropout, training=self.training) - for layer in self.layer_stacks: - x, attn_weight = layer(x, past.encoder_outputs, self_attn_mask=self_attn_mask, - encoder_attn_mask=past.encoder_mask, past=past, inference=inference) - - output = torch.matmul(x, self.output_embed) - - if return_attention: - return output, attn_weight - return output - - @torch.no_grad() - def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]: - """ - # todo: 是否不需要return past? 因为past已经被改变了,不需要显式return? - :param tokens: torch.LongTensor (batch_size,1) - :param past: TransformerPast - :return: - """ - output = self.forward(tokens, past, inference=True) # batch,1,vocab_size - return output.squeeze(1), past - - def reorder_past(self, indices: torch.LongTensor, past: TransformerPast) -> TransformerPast: - past.reorder_past(indices) - return past - - def _get_triangle_mask(self, max_seq_len): - tensor = torch.ones(max_seq_len, max_seq_len) - return torch.tril(tensor).byte() - - -class LSTMPast(Past): - def __init__(self, encode_outputs=None, encode_mask=None, decode_states=None, hx=None): - """ - - :param torch.Tensor encode_outputs: batch_size x max_len x input_size - :param torch.Tensor encode_mask: batch_size x max_len, 与encode_outputs一样大,用以辅助decode的时候attention到正确的 - 词。为1的地方有词 - :param torch.Tensor decode_states: batch_size x decode_len x hidden_size, Decoder中LSTM的输出结果 - :param tuple hx: 包含LSTM所需要的h与c,h: num_layer x batch_size x hidden_size, c: num_layer x batch_size x hidden_size - """ - super().__init__() - self._encode_outputs = encode_outputs - if encode_mask is None: - if encode_outputs is not None: - self._encode_mask = encode_outputs.new_ones(encode_outputs.size(0), encode_outputs.size(1)).eq(1) - else: - self._encode_mask = None - else: - self._encode_mask = encode_mask - self._decode_states = decode_states - self._hx = hx # 包含了hidden和cell - self._attn_states = None # 当LSTM使用了Attention时会用到 - - def num_samples(self): - for tensor in (self.encode_outputs, self.decode_states, self.hx): - if tensor is not None: - if isinstance(tensor, torch.Tensor): - return tensor.size(0) - else: - return tensor[0].size(0) - return None - - def _reorder_past(self, state, indices, dim=0): - if type(state) == torch.Tensor: - state = state.index_select(index=indices, dim=dim) - elif type(state) == tuple: - tmp_list = [] - for i in range(len(state)): - assert state[i] is not None - tmp_list.append(state[i].index_select(index=indices, dim=dim)) - state = tuple(tmp_list) - else: - raise ValueError('State does not support other format') - - return state - - def reorder_past(self, indices: torch.LongTensor): - self.encode_outputs = self._reorder_past(self.encode_outputs, indices) - self.encode_mask = self._reorder_past(self.encode_mask, indices) - self.hx = self._reorder_past(self.hx, indices, 1) - if self.attn_states is not None: - self.attn_states = self._reorder_past(self.attn_states, indices) - - @property - def hx(self): - return self._hx - - @hx.setter - def hx(self, hx): - self._hx = hx - - @property - def encode_outputs(self): - return self._encode_outputs - - @encode_outputs.setter - def encode_outputs(self, value): - self._encode_outputs = value - - @property - def encode_mask(self): - return self._encode_mask - - @encode_mask.setter - def encode_mask(self, value): - self._encode_mask = value - - @property - def decode_states(self): - return self._decode_states - - @decode_states.setter - def decode_states(self, value): - self._decode_states = value - - @property - def attn_states(self): - """ - 表示LSTMDecoder中attention模块的结果,正常情况下不需要手动设置 - :return: - """ - return self._attn_states - - @attn_states.setter - def attn_states(self, value): - self._attn_states = value +from ...embeddings import StaticEmbedding +from ...core import Vocabulary +import abc +import torch +from typing import Union class AttentionLayer(nn.Module): @@ -548,167 +49,408 @@ class AttentionLayer(nn.Module): return x, attn_scores -class LSTMDecoder(Decoder): - def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers=3, input_size=400, - hidden_size=None, dropout=0, - output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None, - bind_input_output_embed=False, - attention=True): - """ - # embed假设是TokenEmbedding, 则没有对应关系(因为可能一个token会对应多个word)?vocab出来的结果是不对的 +# ----- class past ----- # - :param embed: 输入的embedding - :param int num_layers: 使用多少层LSTM - :param int input_size: 输入被encode后的维度 - :param int hidden_size: LSTM中的隐藏层维度 - :param float dropout: 多层LSTM的dropout - :param int output_embed: 输出的词表如何初始化,如果bind_input_output_embed为True,则改值无效 - :param bool bind_input_output_embed: 是否将输入输出的embedding权重使用同一个 - :param bool attention: 是否使用attention对encode之后的内容进行计算 - """ +class Past: + def __init__(self): + pass + @abc.abstractmethod + def num_samples(self): + raise NotImplementedError + + def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0): + if type(state) == torch.Tensor: + state = state.index_select(index=indices, dim=dim) + elif type(state) == list: + for i in range(len(state)): + assert state[i] is not None + state[i] = self._reorder_state(state[i], indices, dim) + elif type(state) == tuple: + tmp_list = [] + for i in range(len(state)): + assert state[i] is not None + tmp_list.append(self._reorder_state(state[i], indices, dim)) + + return state + + +class TransformerPast(Past): + def __init__(self, num_decoder_layer: int = 6): super().__init__() - self.token_embed = get_embeddings(embed) - if hidden_size is None: - hidden_size = input_size - self.hidden_size = hidden_size - self.input_size = input_size - if num_layers == 1: - self.lstm = nn.LSTM(self.token_embed.embedding_dim + hidden_size, hidden_size, num_layers=num_layers, - bidirectional=False, batch_first=True) - else: - self.lstm = nn.LSTM(self.token_embed.embedding_dim + hidden_size, hidden_size, num_layers=num_layers, - bidirectional=False, batch_first=True, dropout=dropout) - if input_size != hidden_size: - self.encode_hidden_proj = nn.Linear(input_size, hidden_size) - self.encode_cell_proj = nn.Linear(input_size, hidden_size) - self.dropout_layer = nn.Dropout(p=dropout) + self.encoder_output = None # batch,src_seq,dim + self.encoder_mask = None + self.encoder_key = [None] * num_decoder_layer + self.encoder_value = [None] * num_decoder_layer + self.decoder_prev_key = [None] * num_decoder_layer + self.decoder_prev_value = [None] * num_decoder_layer - if isinstance(output_embed, int): - output_embed = (output_embed, hidden_size) - output_embed = get_embeddings(output_embed) - elif output_embed is not None: - assert not bind_input_output_embed, "When `output_embed` is not None, `bind_input_output_embed` must " \ - "be False." - if isinstance(output_embed, StaticEmbedding): - for i in self.token_embed.words_to_words: - assert i == self.token_embed.words_to_words[i], "The index does not match." - output_embed = self.token_embed.embedding.weight - else: - output_embed = get_embeddings(output_embed) - else: - if not bind_input_output_embed: - raise RuntimeError("You have to specify output embedding.") + def num_samples(self): + if self.encoder_key[0] is not None: + return self.encoder_key[0].size(0) + return None - if bind_input_output_embed: - assert output_embed is None, "When `bind_input_output_embed=True`, `output_embed` must be None" - if isinstance(self.token_embed, StaticEmbedding): - for i in self.token_embed.words_to_words: - assert i == self.token_embed.words_to_words[i], "The index does not match." - self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1)) - self.output_hidden_size = self.token_embed.embedding_dim - else: - if isinstance(output_embed, nn.Embedding): - self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1)) - else: - self.output_embed = output_embed.transpose(0, 1) - self.output_hidden_size = self.output_embed.size(0) + def reorder_past(self, indices: torch.LongTensor): + self.encoder_output = self._reorder_state(self.encoder_output, indices) + self.encoder_mask = self._reorder_state(self.encoder_mask, indices) + self.encoder_key = self._reorder_state(self.encoder_key, indices) + self.encoder_value = self._reorder_state(self.encoder_value, indices) + self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) + self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) - self.ffn = nn.Sequential(nn.Linear(hidden_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, self.output_hidden_size)) - self.num_layers = num_layers - if attention: - self.attention_layer = AttentionLayer(hidden_size, input_size, hidden_size, bias=False) - else: - self.attention_layer = None +class LSTMPast(Past): + def __init__(self): + self.encoder_output = None # batch,src_seq,dim + self.encoder_mask = None + self.prev_hidden = None # n_layer,batch,dim + self.pre_cell = None # n_layer,batch,dim + self.input_feed = None # batch,dim - def _init_hx(self, past, tokens): - batch_size = tokens.size(0) - if past.hx is None: - zeros = tokens.new_zeros((self.num_layers, batch_size, self.hidden_size)).float() - past.hx = (zeros, zeros) - else: - assert past.hx[0].size(-1) == self.input_size - if self.attention_layer is not None: - if past.attn_states is None: - past.attn_states = past.hx[0].new_zeros(batch_size, self.hidden_size) - else: - assert past.attn_states.size(-1) == self.hidden_size, "The attention states dimension mismatch." - if self.hidden_size != past.hx[0].size(-1): - hidden, cell = past.hx - hidden = self.encode_hidden_proj(hidden) - cell = self.encode_cell_proj(cell) - past.hx = (hidden, cell) - return past + def num_samples(self): + if self.prev_hidden is not None: + return self.prev_hidden.size(0) + return None - def forward(self, tokens, past=None, return_attention=False): + def reorder_past(self, indices: torch.LongTensor): + self.encoder_output = self._reorder_state(self.encoder_output, indices) + self.encoder_mask = self._reorder_state(self.encoder_mask, indices) + self.prev_hidden = self._reorder_state(self.prev_hidden, indices, dim=1) + self.pre_cell = self._reorder_state(self.pre_cell, indices, dim=1) + self.input_feed = self._reorder_state(self.input_feed, indices) + + +# ------ # + +class Seq2SeqDecoder(nn.Module): + def __init__(self, vocab): + super().__init__() + self.vocab = vocab + self._past = None + + def forward(self, tgt_prev_words, encoder_output, encoder_mask, past=None, return_attention=False): + raise NotImplementedError + + def init_past(self, *args, **kwargs): + raise NotImplementedError + + def reset_past(self): + self._past = None + + def train(self, mode=True): + self.reset_past() + super().train() + + def reorder_past(self, indices: torch.LongTensor, past: Past = None): """ + 根据indices中的index,将past的中状态置为正确的顺序 - :param torch.LongTensor, tokens: batch_size x decode_len, 应该输入整个句子 - :param LSTMPast past: 应该包含了encode的输出 - :param bool return_attention: 是否返回各处attention的值 + :param torch.LongTensor indices: + :param Past past: :return: """ - batch_size, decode_len = tokens.size() - tokens = self.token_embed(tokens) # b x decode_len x embed_size + raise NotImplemented - past = self._init_hx(past, tokens) - - tokens = self.dropout_layer(tokens) - - decode_states = tokens.new_zeros((batch_size, decode_len, self.hidden_size)) - if self.attention_layer is not None: - attn_scores = tokens.new_zeros((tokens.size(0), tokens.size(1), past.encode_outputs.size(1))) - if self.attention_layer is not None: - input_feed = past.attn_states - else: - input_feed = past.hx[0][-1] - for i in range(tokens.size(1)): - input = torch.cat([tokens[:, i:i + 1], input_feed.unsqueeze(1)], dim=2) # batch_size x 1 x h' - # bsz x 1 x hidden_size, (n_layer x bsz x hidden_size, n_layer x bsz x hidden_size) - - _, (hidden, cell) = self.lstm(input, hx=past.hx) - past.hx = (hidden, cell) - if self.attention_layer is not None: - input_feed, attn_score = self.attention_layer(hidden[-1], past.encode_outputs, past.encode_mask) - attn_scores[:, i] = attn_score - past.attn_states = input_feed - else: - input_feed = hidden[-1] - decode_states[:, i] = input_feed - - decode_states = self.dropout_layer(decode_states) - - outputs = self.ffn(decode_states) # batch_size x decode_len x output_hidden_size - - feats = torch.matmul(outputs, self.output_embed) # bsz x decode_len x vocab_size - if return_attention: - return feats, attn_scores - else: - return feats + # def decode(self, *args, **kwargs) -> torch.Tensor: + # """ + # 当模型进行解码时,使用这个函数。只返回一个batch_size x vocab_size的结果。需要考虑一种特殊情况,即tokens长度不是1,即给定了 + # 解码句子开头的情况,这种情况需要查看Past中是否正确计算了decode的状态 + # + # :return: + # """ + # raise NotImplemented @torch.no_grad() - def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]: + def decode(self, tgt_prev_words, encoder_output, encoder_mask, past=None) -> torch.Tensor: """ - 给定上一个位置的输出,决定当前位置的输出。 - :param torch.LongTensor tokens: batch_size x seq_len - :param LSTMPast past: + :param tgt_prev_words: 传入的是完整的prev tokens + :param encoder_output: + :param encoder_mask: + :param past :return: """ - # past = self._init_hx(past, tokens) - tokens = tokens[:, -1:] - feats = self.forward(tokens, past, return_attention=False) - return feats.squeeze(1), past + if past is None: + past = self._past + assert past is not None + output = self.forward(tgt_prev_words, encoder_output, encoder_mask, past) # batch,1,vocab_size + return output.squeeze(1) - def reorder_past(self, indices: torch.LongTensor, past: LSTMPast) -> LSTMPast: + +class TransformerSeq2SeqDecoderLayer(nn.Module): + def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, + layer_idx: int = None): + super().__init__() + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + self.layer_idx = layer_idx # 记录layer的层索引,以方便获取past的信息 + + self.self_attn = MultiheadAttention(d_model, n_head, dropout, layer_idx) + self.self_attn_layer_norm = LayerNorm(d_model) + + self.encoder_attn = MultiheadAttention(d_model, n_head, dropout, layer_idx) + self.encoder_attn_layer_norm = LayerNorm(d_model) + + self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(self.dim_ff, self.d_model), + nn.Dropout(dropout)) + + self.final_layer_norm = LayerNorm(self.d_model) + + def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, past=None): """ - 将LSTMPast中的状态重置一下 - :param torch.LongTensor indices: 在batch维度的index - :param LSTMPast past: 保存的过去的状态 + :param x: (batch, seq_len, dim), decoder端的输入 + :param encoder_output: (batch,src_seq_len,dim) + :param encoder_mask: batch,src_seq_len + :param self_attn_mask: seq_len, seq_len,下三角的mask矩阵,只在训练时传入 + :param past: 只在inference阶段传入 :return: """ + + # self attention part + residual = x + x = self.self_attn_layer_norm(x) + x, _ = self.self_attn(query=x, + key=x, + value=x, + attn_mask=self_attn_mask, + past=past) + + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + # encoder attention part + residual = x + x = self.encoder_attn_layer_norm(x) + x, attn_weight = self.encoder_attn(query=x, + key=encoder_output, + value=encoder_output, + key_mask=encoder_mask, + past=past) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + # ffn + residual = x + x = self.final_layer_norm(x) + x = self.ffn(x) + x = residual + x + + return x, attn_weight + + +class TransformerSeq2SeqDecoder(Seq2SeqDecoder): + def __init__(self, vocab: Vocabulary, embed: nn.Module, pos_embed: nn.Module = None, num_layers: int = 6, + d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, + output_embed: nn.Parameter = None): + """ + + :param embed: decoder端输入的embedding + :param num_layers: Transformer Decoder层数 + :param d_model: Transformer参数 + :param n_head: Transformer参数 + :param dim_ff: Transformer参数 + :param dropout: + :param output_embed: 输出embedding + """ + super().__init__(vocab) + + self.embed = embed + self.pos_embed = pos_embed + self.num_layers = num_layers + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + + self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx) + for layer_idx in range(num_layers)]) + + self.embed_scale = math.sqrt(d_model) + self.layer_norm = LayerNorm(d_model) + self.output_embed = output_embed # len(vocab), d_model + + def forward(self, tgt_prev_words, encoder_output, encoder_mask, past=None, return_attention=False): + """ + + :param tgt_prev_words: batch, tgt_len + :param encoder_output: batch, src_len, dim + :param encoder_mask: batch, src_seq + :param past: + :param return_attention: + :return: + """ + batch_size, max_tgt_len = tgt_prev_words.size() + device = tgt_prev_words.device + + position = torch.arange(1, max_tgt_len + 1).unsqueeze(0).long().to(device) + if past is not None: # 此时在inference阶段 + position = position[:, -1] + tgt_prev_words = tgt_prev_words[:-1] + + x = self.embed_scale * self.embed(tgt_prev_words) + if self.pos_embed is not None: + x += self.pos_embed(position) + x = F.dropout(x, p=self.dropout, training=self.training) + + if past is None: + triangle_mask = self._get_triangle_mask(max_tgt_len) + triangle_mask = triangle_mask.to(device) + else: + triangle_mask = None + + for layer in self.layer_stacks: + x, attn_weight = layer(x=x, + encoder_output=encoder_output, + encoder_mask=encoder_mask, + self_attn_mask=triangle_mask, + past=past + ) + + x = self.layer_norm(x) # batch, tgt_len, dim + output = F.linear(x, self.output_embed) + + if return_attention: + return output, attn_weight + return output + + def reorder_past(self, indices: torch.LongTensor, past: TransformerPast = None) -> TransformerPast: + if past is None: + past = self._past past.reorder_past(indices) return past + + @property + def past(self): + return self._past + + def init_past(self, encoder_output=None, encoder_mask=None): + self._past = TransformerPast(self.num_layers) + self._past.encoder_output = encoder_output + self._past.encoder_mask = encoder_mask + + @past.setter + def past(self, past): + assert isinstance(past, TransformerPast) + self._past = past + + @staticmethod + def _get_triangle_mask(max_seq_len): + tensor = torch.ones(max_seq_len, max_seq_len) + return torch.tril(tensor).byte() + + +class LSTMSeq2SeqDecoder(Seq2SeqDecoder): + def __init__(self, vocab: Vocabulary, embed: nn.Module, num_layers: int = 3, hidden_size: int = 300, + dropout: float = 0.3, output_embed: nn.Parameter = None, attention=True): + super().__init__(vocab) + + self.embed = embed + self.output_embed = output_embed + self.embed_dim = embed.embedding_dim + self.hidden_size = hidden_size + self.num_layers = num_layers + self.lstm = nn.LSTM(input_size=self.embed_dim + hidden_size, hidden_size=hidden_size, num_layers=num_layers, + batch_first=True, bidirectional=False, dropout=dropout) + self.attention_layer = AttentionLayer(hidden_size, self.embed_dim, hidden_size) if attention else None + assert self.attention_layer is not None, "Attention Layer is required for now" # todo 支持不做attention + self.dropout_layer = nn.Dropout(dropout) + + def forward(self, tgt_prev_words, encoder_output, encoder_mask, past=None, return_attention=False): + """ + + :param tgt_prev_words: batch, tgt_len + :param encoder_output: + output: batch, src_len, dim + (hidden,cell): num_layers, batch, dim + :param encoder_mask: batch, src_seq + :param past: + :param return_attention: + :return: + """ + # input feed就是上一个时间步的最后一层layer的hidden state和out的融合 + + batch_size, max_tgt_len = tgt_prev_words.size() + device = tgt_prev_words.device + src_output, (src_final_hidden, src_final_cell) = encoder_output + if past is not None: + tgt_prev_words = tgt_prev_words[:-1] # 只取最后一个 + + x = self.embed(tgt_prev_words) + x = self.dropout_layer(x) + + attn_weights = [] if self.attention_layer is not None else None # 保存attention weight, batch,tgt_seq,src_seq + input_feed = None + cur_hidden = None + cur_cell = None + + if past is not None: # 若past存在,则从中获取历史input feed + input_feed = past.input_feed + + if input_feed is None: + input_feed = src_final_hidden[-1] # 以encoder的hidden作为初值, batch, dim + decoder_out = [] + + if past is not None: + cur_hidden = past.prev_hidden + cur_cell = past.prev_cell + + if cur_hidden is None: + cur_hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size) + cur_cell = torch.zeros(self.num_layers, batch_size, self.hidden_size) + + # 开始计算 + for i in range(max_tgt_len): + input = torch.cat( + (x[:, i:i + 1, :], + input_feed[:, None, :] + ), + dim=2 + ) # batch,1,2*dim + _, (cur_hidden, cur_cell) = self.lstm(input, hx=(cur_hidden, cur_cell)) # hidden/cell保持原来的size + if self.attention_layer is not None: + input_feed, attn_weight = self.attention_layer(cur_hidden[-1], src_output, encoder_mask) + attn_weights.append(attn_weight) + else: + input_feed = cur_hidden[-1] + + if past is not None: # 保存状态 + past.input_feed = input_feed # batch, hidden + past.prev_hidden = cur_hidden + past.prev_cell = cur_cell + decoder_out.append(input_feed) + + decoder_out = torch.cat(decoder_out, dim=1) # batch,seq_len,hidden + decoder_out = self.dropout_layer(decoder_out) + if attn_weights is not None: + attn_weights = torch.cat(attn_weights, dim=1) # batch, tgt_len, src_len + + output = F.linear(decoder_out, self.output_embed) + if return_attention: + return output, attn_weights + return output + + def reorder_past(self, indices: torch.LongTensor, past: LSTMPast) -> LSTMPast: + if past is None: + past = self._past + past.reorder_past(indices) + + return past + + def init_past(self, encoder_output=None, encoder_mask=None): + self._past = LSTMPast() + self._past.encoder_output = encoder_output + self._past.encoder_mask = encoder_mask + + @property + def past(self): + return self._past + + @past.setter + def past(self, past): + assert isinstance(past, LSTMPast) + self._past = past diff --git a/fastNLP/modules/decoder/seq2seq_generator.py b/fastNLP/modules/decoder/seq2seq_generator.py index 4ee2c787..d7d0c71f 100644 --- a/fastNLP/modules/decoder/seq2seq_generator.py +++ b/fastNLP/modules/decoder/seq2seq_generator.py @@ -2,23 +2,29 @@ __all__ = [ "SequenceGenerator" ] import torch -from .seq2seq_decoder import Decoder +from ...models.seq2seq_model import BaseSeq2SeqModel +from ..encoder.seq2seq_encoder import Seq2SeqEncoder +from ..decoder.seq2seq_decoder import Seq2SeqDecoder import torch.nn.functional as F from ...core.utils import _get_model_device from functools import partial +from ...core import Vocabulary class SequenceGenerator: - def __init__(self, decoder: Decoder, max_length=20, num_beams=1, + def __init__(self, encoder: Seq2SeqEncoder = None, decoder: Seq2SeqDecoder = None, + max_length=20, num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, repetition_penalty=1, length_penalty=1.0): if do_sample: - self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams, + self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, + num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, eos_token_id=eos_token_id, repetition_penalty=repetition_penalty, length_penalty=length_penalty) else: - self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, num_beams=num_beams, + self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, + num_beams=num_beams, bos_token_id=bos_token_id, eos_token_id=eos_token_id, repetition_penalty=repetition_penalty, length_penalty=length_penalty) @@ -32,30 +38,45 @@ class SequenceGenerator: self.eos_token_id = eos_token_id self.repetition_penalty = repetition_penalty self.length_penalty = length_penalty + # self.vocab = tgt_vocab + self.encoder = encoder self.decoder = decoder @torch.no_grad() - def generate(self, tokens=None, past=None): + def generate(self, src_tokens: torch.Tensor = None, src_seq_len: torch.Tensor = None, prev_tokens=None): """ - :param torch.LongTensor tokens: batch_size x length, 开始的token - :param past: + :param src_tokens: + :param src_seq_len: + :param prev_tokens: :return: """ - # TODO 需要查看如果tokens长度不是1,decode的时候是否还能够直接decode? - return self.generate_func(tokens=tokens, past=past) + if self.encoder is not None: + encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len) + else: + encoder_output = encoder_mask = None + + # 每次都初始化past + if encoder_output is not None: + self.decoder.init_past(encoder_output, encoder_mask) + else: + self.decoder.init_past() + return self.generate_func(src_tokens, src_seq_len, prev_tokens) @torch.no_grad() -def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, +def greedy_generate(decoder: Seq2SeqDecoder, encoder_output=None, encoder_mask=None, + prev_tokens=None, max_length=20, num_beams=1, bos_token_id=None, eos_token_id=None, repetition_penalty=1, length_penalty=1.0): """ 贪婪地搜索句子 - :param Decoder decoder: Decoder对象 - :param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 - :param Past past: 应该包好encoder的一些输出。 + + :param decoder: + :param encoder_output: + :param encoder_mask: + :param prev_tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 :param int max_length: 生成句子的最大长度。 :param int num_beams: 使用多大的beam进行解码。 :param int bos_token_id: 如果tokens传入为None,则使用bos_token_id开始往后解码。 @@ -65,11 +86,18 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, :return: """ if num_beams == 1: - token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=1, top_k=50, top_p=1, + token_ids = _no_beam_search_generate(decoder=decoder, + encoder_output=encoder_output, encoder_mask=encoder_mask, + prev_tokens=prev_tokens, + max_length=max_length, temperature=1, + top_k=50, top_p=1, bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, repetition_penalty=repetition_penalty, length_penalty=length_penalty) else: - token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, + token_ids = _beam_search_generate(decoder=decoder, + encoder_output=encoder_output, encoder_mask=encoder_mask, + prev_tokens=prev_tokens, max_length=max_length, + num_beams=num_beams, temperature=1, top_k=50, top_p=1, bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, repetition_penalty=repetition_penalty, length_penalty=length_penalty) @@ -78,14 +106,17 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, @torch.no_grad() -def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, temperature=1.0, top_k=50, +def sample_generate(decoder: Seq2SeqDecoder, encoder_output=None, encoder_mask=None, + prev_tokens=None, max_length=20, num_beams=1, + temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, repetition_penalty=1.0, length_penalty=1.0): """ 使用采样的方法生成句子 - :param Decoder decoder: Decoder对象 - :param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 - :param Past past: 应该包好encoder的一些输出。 + :param decoder + :param encoder_output: + :param encoder_mask: + :param torch.LongTensor prev_tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 :param int max_length: 生成句子的最大长度。 :param int num_beam: 使用多大的beam进行解码。 :param float temperature: 采样时的退火大小 @@ -99,50 +130,55 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, """ # 每个位置在生成的时候会sample生成 if num_beams == 1: - token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=temperature, + token_ids = _no_beam_search_generate(decoder=decoder, encoder_output=encoder_output, encoder_mask=encoder_mask, + prev_tokens=prev_tokens, max_length=max_length, + temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, repetition_penalty=repetition_penalty, length_penalty=length_penalty) else: - token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, + token_ids = _beam_search_generate(decoder=decoder, encoder_output=encoder_output, encoder_mask=encoder_mask, + prev_tokens=prev_tokens, max_length=max_length, + num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, repetition_penalty=repetition_penalty, length_penalty=length_penalty) return token_ids -def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, temperature=1.0, top_k=50, - top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, +def _no_beam_search_generate(decoder: Seq2SeqDecoder, + encoder_output=None, encoder_mask: torch.Tensor = None, + prev_tokens: torch.Tensor = None, max_length=20, + temperature=1.0, top_k=50, + top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=False, repetition_penalty=1.0, length_penalty=1.0): + if encoder_output is not None: + batch_size = encoder_output.size(0) + else: + assert prev_tokens is not None, "You have to specify either `src_tokens` or `prev_tokens`" + batch_size = prev_tokens.size(0) device = _get_model_device(decoder) - if tokens is None: + + if prev_tokens is None: if bos_token_id is None: - raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") - if past is None: - raise RuntimeError("You have to specify either `past` or `tokens`.") - batch_size = past.num_samples() - if batch_size is None: - raise RuntimeError("Cannot infer the number of samples from `past`.") - tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) - batch_size = tokens.size(0) - if past is not None: - assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." + raise RuntimeError("You have to specify either `prev_tokens` or `bos_token_id`.") + + prev_tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) if eos_token_id is None: _eos_token_id = float('nan') else: _eos_token_id = eos_token_id - for i in range(tokens.size(1)): - scores, past = decoder.decode(tokens[:, :i + 1], past) # batch_size x vocab_size, Past + for i in range(prev_tokens.size(1)): # 先过一遍pretoken,做初始化 + decoder.decode(prev_tokens[:, :i + 1], encoder_output, encoder_mask) - token_ids = tokens.clone() + token_ids = prev_tokens.clone() # 保存所有生成的token cur_len = token_ids.size(1) dones = token_ids.new_zeros(batch_size).eq(1) - # tokens = tokens[:, -1:] while cur_len < max_length: - scores, past = decoder.decode(tokens, past) # batch_size x vocab_size, Past + scores = decoder.decode(token_ids, encoder_output, encoder_mask) # batch_size x vocab_size if repetition_penalty != 1.0: token_scores = scores.gather(dim=1, index=token_ids) @@ -171,9 +207,9 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt next_tokens = torch.argmax(scores, dim=-1) # batch_size next_tokens = next_tokens.masked_fill(dones, 0) # 对已经搜索完成的sample做padding - tokens = next_tokens.unsqueeze(1) + next_tokens = next_tokens.unsqueeze(1) - token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len + token_ids = torch.cat([token_ids, next_tokens], dim=-1) # batch_size x max_len end_mask = next_tokens.eq(_eos_token_id) dones = dones.__or__(end_mask) @@ -189,29 +225,31 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt return token_ids -def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, num_beams=4, temperature=1.0, +def _beam_search_generate(decoder: Seq2SeqDecoder, + encoder_output=None, encoder_mask: torch.Tensor = None, + prev_tokens: torch.Tensor = None, max_length=20, num_beams=4, temperature=1.0, top_k=50, - top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, + top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=False, repetition_penalty=1.0, length_penalty=None) -> torch.LongTensor: # 进行beam search - device = _get_model_device(decoder) - if tokens is None: - if bos_token_id is None: - raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") - if past is None: - raise RuntimeError("You have to specify either `past` or `tokens`.") - batch_size = past.num_samples() - if batch_size is None: - raise RuntimeError("Cannot infer the number of samples from `past`.") - tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) - batch_size = tokens.size(0) - if past is not None: - assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." - for i in range(tokens.size(1) - 1): # 如果输入的长度较长,先decode - scores, past = decoder.decode(tokens[:, :i + 1], - past) # (batch_size, vocab_size), Past - scores, past = decoder.decode(tokens, past) # 这里要传入的是整个句子的长度 + if encoder_output is not None: + batch_size = encoder_output.size(0) + else: + assert prev_tokens is not None, "You have to specify either `src_tokens` or `prev_tokens`" + batch_size = prev_tokens.size(0) + + device = _get_model_device(decoder) + + if prev_tokens is None: + if bos_token_id is None: + raise RuntimeError("You have to specify either `prev_tokens` or `bos_token_id`.") + + prev_tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) + + for i in range(prev_tokens.size(1)): # 如果输入的长度较长,先decode + scores = decoder.decode(prev_tokens[:, :i + 1], encoder_output, encoder_mask) + vocab_size = scores.size(1) assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size." @@ -225,15 +263,15 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 # 得到(batch_size, num_beams), (batch_size, num_beams) next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) + # 根据index来做顺序的调转 indices = torch.arange(batch_size, dtype=torch.long).to(device) indices = indices.repeat_interleave(num_beams) - decoder.reorder_past(indices, past) + decoder.reorder_past(indices) + prev_tokens = prev_tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length - tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length # 记录生成好的token (batch_size', cur_len) - token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1) + token_ids = torch.cat([prev_tokens, next_tokens.view(-1, 1)], dim=-1) dones = [False] * batch_size - tokens = next_tokens.view(-1, 1) beam_scores = next_scores.view(-1) # batch_size * num_beams @@ -247,7 +285,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) while cur_len < max_length: - scores, past = decoder.decode(tokens, past) # batch_size * num_beams x vocab_size, Past + scores = decoder.decode(token_ids, encoder_output, encoder_mask) # batch_size * num_beams x vocab_size if repetition_penalty != 1.0: token_scores = scores.gather(dim=1, index=token_ids) @@ -300,9 +338,9 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams) beam_scores = _next_scores.view(-1) - # 更改past状态, 重组token_ids + # 重组past/encoder状态, 重组token_ids reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维 - decoder.reorder_past(reorder_inds, past) + decoder.reorder_past(reorder_inds) flag = True if cur_len + 1 == max_length: @@ -327,8 +365,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) # 重新组织token_ids的状态 - tokens = _next_tokens - token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1) + cur_tokens = _next_tokens + token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), cur_tokens], dim=-1) for batch_idx in range(batch_size): dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) @@ -436,38 +474,3 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf") indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) logits[indices_to_remove] = filter_value return logits - - -if __name__ == '__main__': - # TODO 需要检查一下greedy_generate和sample_generate是否正常工作。 - from torch import nn - - - class DummyDecoder(nn.Module): - def __init__(self, num_words): - super().__init__() - self.num_words = num_words - - def decode(self, tokens, past): - batch_size = tokens.size(0) - return torch.randn(batch_size, self.num_words), past - - def reorder_past(self, indices, past): - return past - - - num_words = 10 - batch_size = 3 - decoder = DummyDecoder(num_words) - - tokens = greedy_generate(decoder=decoder, tokens=torch.zeros(batch_size, 1).long(), past=None, max_length=20, - num_beams=2, - bos_token_id=0, eos_token_id=num_words - 1, - repetition_penalty=1, length_penalty=1.0) - print(tokens) - - tokens = sample_generate(decoder, tokens=torch.zeros(batch_size, 1).long(), - past=None, max_length=20, num_beams=2, temperature=1.0, top_k=50, - top_p=1.0, bos_token_id=0, eos_token_id=num_words - 1, repetition_penalty=1.0, - length_penalty=1.0) - print(tokens) diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py index 579dddd4..1f6a8003 100644 --- a/fastNLP/modules/encoder/__init__.py +++ b/fastNLP/modules/encoder/__init__.py @@ -31,8 +31,9 @@ __all__ = [ "BiAttention", "SelfAttention", - "BiLSTMEncoder", - "TransformerSeq2SeqEncoder" + "LSTMSeq2SeqEncoder", + "TransformerSeq2SeqEncoder", + "Seq2SeqEncoder" ] from .attention import MultiHeadAttention, BiAttention, SelfAttention @@ -45,4 +46,4 @@ from .star_transformer import StarTransformer from .transformer import TransformerEncoder from .variational_rnn import VarRNN, VarLSTM, VarGRU -from .seq2seq_encoder import BiLSTMEncoder, TransformerSeq2SeqEncoder +from .seq2seq_encoder import LSTMSeq2SeqEncoder, TransformerSeq2SeqEncoder, Seq2SeqEncoder diff --git a/fastNLP/modules/encoder/seq2seq_encoder.py b/fastNLP/modules/encoder/seq2seq_encoder.py index 1474c864..994003e6 100644 --- a/fastNLP/modules/encoder/seq2seq_encoder.py +++ b/fastNLP/modules/encoder/seq2seq_encoder.py @@ -1,48 +1,238 @@ -__all__ = [ - "TransformerSeq2SeqEncoder", - "BiLSTMEncoder" -] - -from torch import nn +import torch.nn as nn import torch -from ...modules.encoder import LSTM -from ...core.utils import seq_len_to_mask -from torch.nn import TransformerEncoder +from torch.nn import LayerNorm +import torch.nn.functional as F from typing import Union, Tuple -import numpy as np +from ...core.utils import seq_len_to_mask +import math +from ...core import Vocabulary +from ...modules import LSTM -class TransformerSeq2SeqEncoder(nn.Module): - def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers: int = 6, - d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1): - super(TransformerSeq2SeqEncoder, self).__init__() - self.embed = embed - self.transformer = TransformerEncoder(nn.TransformerEncoderLayer(d_model, n_head,dim_ff,dropout), num_layers) +class MultiheadAttention(nn.Module): # todo 这个要放哪里? + def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): + super(MultiheadAttention, self).__init__() + self.d_model = d_model + self.n_head = n_head + self.dropout = dropout + self.head_dim = d_model // n_head + self.layer_idx = layer_idx + assert d_model % n_head == 0, "d_model should be divisible by n_head" + self.scaling = self.head_dim ** -0.5 - def forward(self, words, seq_len): + self.q_proj = nn.Linear(d_model, d_model) + self.k_proj = nn.Linear(d_model, d_model) + self.v_proj = nn.Linear(d_model, d_model) + self.out_proj = nn.Linear(d_model, d_model) + + self.reset_parameters() + + def forward(self, query, key, value, key_mask=None, attn_mask=None, past=None): """ - :param words: batch, seq_len - :param seq_len: - :return: output: (batch, seq_len,dim) ; encoder_mask + :param query: batch x seq x dim + :param key: + :param value: + :param key_mask: batch x seq 用于指示哪些key不要attend到;注意到mask为1的地方是要attend到的 + :param attn_mask: seq x seq, 用于mask掉attention map。 主要是用在训练时decoder端的self attention,下三角为1 + :param past: 过去的信息,在inference的时候会用到,比如encoder output、decoder的prev kv。这样可以减少计算。 + :return: """ - words = self.embed(words) # batch, seq_len, dim - words = words.transpose(0, 1) - encoder_mask = seq_len_to_mask(seq_len) # batch, seq - words = self.transformer(words, src_key_padding_mask=~encoder_mask) # seq_len,batch,dim + assert key.size() == value.size() + if past is not None: + assert self.layer_idx is not None + qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() - return words.transpose(0, 1), encoder_mask + q = self.q_proj(query) # batch x seq x dim + q *= self.scaling + k = v = None + prev_k = prev_v = None + + # 从past中取kv + if past is not None: # 说明此时在inference阶段 + if qkv_same: # 此时在decoder self attention + prev_k = past.decoder_prev_key[self.layer_idx] + prev_v = past.decoder_prev_value[self.layer_idx] + else: # 此时在decoder-encoder attention,直接将保存下来的key装载起来即可 + k = past.encoder_key[self.layer_idx] + v = past.encoder_value[self.layer_idx] + + if k is None: + k = self.k_proj(key) + v = self.v_proj(value) + + if prev_k is not None: + k = torch.cat((prev_k, k), dim=1) + v = torch.cat((prev_v, v), dim=1) + + # 更新past + if past is not None: + if qkv_same: + past.decoder_prev_key[self.layer_idx] = k + past.decoder_prev_value[self.layer_idx] = v + else: + past.encoder_key[self.layer_idx] = k + past.encoder_value[self.layer_idx] = v + + # 开始计算attention + batch_size, q_len, d_model = query.size() + k_len, v_len = k.size(1), v.size(1) + q = q.contiguous().view(batch_size, q_len, self.n_head, self.head_dim) + k = k.contiguous().view(batch_size, k_len, self.n_head, self.head_dim) + v = v.contiguous().view(batch_size, v_len, self.n_head, self.head_dim) + + attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k) # bs,q_len,k_len,n_head + if key_mask is not None: + _key_mask = ~key_mask[:, None, :, None].bool() # batch,1,k_len,n_head + attn_weights = attn_weights.masked_fill(_key_mask, -float('inf')) + + if attn_mask is not None: + _attn_mask = ~attn_mask[None, :, :, None].bool() # 1,q_len,k_len,n_head + attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf')) + + attn_weights = F.softmax(attn_weights, dim=2) + attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) + + output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim + output = output.reshape(batch_size, q_len, -1) + output = self.out_proj(output) # batch,q_len,dim + + return output, attn_weights + + def reset_parameters(self): + nn.init.xavier_uniform_(self.q_proj.weight) + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.out_proj.weight) + + def set_layer_idx(self, layer_idx): + self.layer_idx = layer_idx -class BiLSTMEncoder(nn.Module): - def __init__(self, embed, num_layers=3, hidden_size=400, dropout=0.3): +class TransformerSeq2SeqEncoderLayer(nn.Module): + def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, + dropout: float = 0.1): + super(TransformerSeq2SeqEncoderLayer, self).__init__() + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + + self.self_attn = MultiheadAttention(d_model, n_head, dropout) + self.attn_layer_norm = LayerNorm(d_model) + self.ffn_layer_norm = LayerNorm(d_model) + + self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(self.dim_ff, self.d_model), + nn.Dropout(dropout)) + + def forward(self, x, encoder_mask): + """ + + :param x: batch,src_seq,dim + :param encoder_mask: batch,src_seq + :return: + """ + # attention + residual = x + x = self.attn_layer_norm(x) + x, _ = self.self_attn(query=x, + key=x, + value=x, + key_mask=encoder_mask) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + # ffn + residual = x + x = self.ffn_layer_norm(x) + x = self.ffn(x) + x = residual + x + + return x + + +class Seq2SeqEncoder(nn.Module): + def __init__(self, vocab): super().__init__() + self.vocab = vocab + + def forward(self, src_words, src_seq_len): + raise NotImplementedError + + +class TransformerSeq2SeqEncoder(Seq2SeqEncoder): + def __init__(self, vocab: Vocabulary, embed: nn.Module, pos_embed: nn.Module = None, num_layers: int = 6, + d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1): + super(TransformerSeq2SeqEncoder, self).__init__(vocab) self.embed = embed - self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size // 2, bidirectional=True, + self.embed_scale = math.sqrt(d_model) + self.pos_embed = pos_embed + self.num_layers = num_layers + self.d_model = d_model + self.n_head = n_head + self.dim_ff = dim_ff + self.dropout = dropout + + self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout) + for _ in range(num_layers)]) + self.layer_norm = LayerNorm(d_model) + + def forward(self, src_words, src_seq_len): + """ + + :param src_words: batch, src_seq_len + :param src_seq_len: [batch] + :return: + """ + batch_size, max_src_len = src_words.size() + device = src_words.device + x = self.embed(src_words) * self.embed_scale # batch, seq, dim + if self.pos_embed is not None: + position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device) + x += self.pos_embed(position) + x = F.dropout(x, p=self.dropout, training=self.training) + + encoder_mask = seq_len_to_mask(src_seq_len) + encoder_mask = encoder_mask.to(device) + + for layer in self.layer_stacks: + x = layer(x, encoder_mask) + + x = self.layer_norm(x) + + return x, encoder_mask + + +class LSTMSeq2SeqEncoder(Seq2SeqEncoder): + def __init__(self, vocab: Vocabulary, embed: nn.Module, num_layers: int = 3, hidden_size: int = 400, + dropout: float = 0.3, bidirectional=True): + super().__init__(vocab) + self.embed = embed + self.num_layers = num_layers + self.dropout = dropout + self.hidden_size = hidden_size + self.bidirectional = bidirectional + self.lstm = LSTM(input_size=embed.embedding_dim, hidden_size=hidden_size // 2, bidirectional=bidirectional, batch_first=True, dropout=dropout, num_layers=num_layers) - def forward(self, words, seq_len): - words = self.embed(words) - words, hx = self.lstm(words, seq_len) + def forward(self, src_words, src_seq_len): + batch_size = src_words.size(0) + device = src_words.device + x = self.embed(src_words) + x, (final_hidden, final_cell) = self.lstm(x, src_seq_len) + encoder_mask = seq_len_to_mask(src_seq_len).to(device) - return words, hx + # x: batch,seq_len,dim; h/c: num_layers*2,batch,dim + + def concat_bidir(input): + output = input.view(self.num_layers, 2, batch_size, -1).transpose(1, 2).contiguous() + return output.view(self.num_layers, batch_size, -1) + + if self.bidirectional: + final_hidden = concat_bidir(final_hidden) # 将双向的hidden state拼接起来,用于接下来的decoder的input + final_cell = concat_bidir(final_cell) + + return (x, (final_hidden, final_cell)), encoder_mask # 为了配合Seq2SeqBaseModel的forward,这边需要分为两个return diff --git a/reproduction/Summarization/Baseline/transformer/Models.py b/reproduction/Summarization/Baseline/transformer/Models.py index d323e785..2d928f96 100644 --- a/reproduction/Summarization/Baseline/transformer/Models.py +++ b/reproduction/Summarization/Baseline/transformer/Models.py @@ -7,10 +7,12 @@ from transformer.Layers import EncoderLayer, DecoderLayer __author__ = "Yu-Hsiang Huang" + def get_non_pad_mask(seq): assert seq.dim() == 2 return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) + def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): ''' Sinusoid position encoding table ''' @@ -31,6 +33,7 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): return torch.FloatTensor(sinusoid_table) + def get_attn_key_pad_mask(seq_k, seq_q): ''' For masking out the padding part of key sequence. ''' @@ -41,6 +44,7 @@ def get_attn_key_pad_mask(seq_k, seq_q): return padding_mask + def get_subsequent_mask(seq): ''' For masking out the subsequent info. ''' @@ -51,6 +55,7 @@ def get_subsequent_mask(seq): return subsequent_mask + class Encoder(nn.Module): ''' A encoder model with self attention mechanism. ''' @@ -98,6 +103,7 @@ class Encoder(nn.Module): return enc_output, enc_slf_attn_list return enc_output, + class Decoder(nn.Module): ''' A decoder model with self attention mechanism. ''' @@ -152,6 +158,7 @@ class Decoder(nn.Module): return dec_output, dec_slf_attn_list, dec_enc_attn_list return dec_output, + class Transformer(nn.Module): ''' A sequence to sequence model with attention mechanism. ''' @@ -181,8 +188,8 @@ class Transformer(nn.Module): nn.init.xavier_normal_(self.tgt_word_prj.weight) assert d_model == d_word_vec, \ - 'To facilitate the residual connections, \ - the dimensions of all module outputs shall be the same.' + 'To facilitate the residual connections, \ + the dimensions of all module outputs shall be the same.' if tgt_emb_prj_weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer @@ -194,7 +201,7 @@ class Transformer(nn.Module): if emb_src_tgt_weight_sharing: # Share the weight matrix between source & target word embeddings assert n_src_vocab == n_tgt_vocab, \ - "To share word embedding table, the vocabulary size of src/tgt shall be the same." + "To share word embedding table, the vocabulary size of src/tgt shall be the same." self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): diff --git a/test/modules/decoder/test_seq2seq_decoder.py b/test/modules/decoder/test_seq2seq_decoder.py index 6c74d527..cd9502bc 100644 --- a/test/modules/decoder/test_seq2seq_decoder.py +++ b/test/modules/decoder/test_seq2seq_decoder.py @@ -2,8 +2,10 @@ import unittest import torch -from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder -from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder +from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder +from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, \ + LSTMSeq2SeqDecoder +from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel from fastNLP import Vocabulary from fastNLP.embeddings import StaticEmbedding from fastNLP.core.utils import seq_len_to_mask @@ -15,22 +17,17 @@ class TestTransformerSeq2SeqDecoder(unittest.TestCase): vocab.add_word_lst("Another test !".split()) embed = StaticEmbedding(vocab, embedding_dim=512) - encoder = TransformerSeq2SeqEncoder(embed) - decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True) + args = TransformerSeq2SeqModel.add_args() + model = TransformerSeq2SeqModel.build_model(args, vocab, vocab) src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) src_seq_len = torch.LongTensor([3, 2]) - encoder_outputs, mask = encoder(src_words_idx, src_seq_len) - past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask) + output = model(src_words_idx, src_seq_len, tgt_words_idx) + print(output) - decoder_outputs = decoder(tgt_words_idx, past) - - print(decoder_outputs) - print(mask) - - self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab))) + # self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab))) def test_decode(self): pass # todo