Mostly completed the basic seq2seq functionality
This commit is contained in:
parent 15360e9724
commit b95aa56afb
@@ -9,18 +9,18 @@ fastNLP provides built-in models in :mod:`~fastNLP.models`, such as :class:`~fastNLP.models
"""

__all__ = [
    "CNNText",

    "SeqLabeling",
    "AdvSeqLabel",
    "BiLSTMCRF",

    "ESIM",

    "StarTransEnc",
    "STSeqLabel",
    "STNLICls",
    "STSeqCls",

    "BiaffineParser",
    "GraphParser",
@@ -30,7 +30,9 @@ __all__ = [
    "BertForTokenClassification",
    "BertForQuestionAnswering",

    "TransformerSeq2SeqModel"
    "TransformerSeq2SeqModel",
    "LSTMSeq2SeqModel",
    "BaseSeq2SeqModel"
]

from .base_model import BaseModel
@@ -41,7 +43,8 @@ from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
from .snli import ESIM
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
from .seq2seq_model import TransformerSeq2SeqModel
from .seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel, BaseSeq2SeqModel
import sys
from ..doc_utils import doc_process

doc_process(sys.modules[__name__])

doc_process(sys.modules[__name__])
@@ -1,26 +1,153 @@
import torch.nn as nn
import torch
from typing import Union, Tuple
from torch import nn
import numpy as np
from fastNLP.modules import TransformerSeq2SeqDecoder, TransformerSeq2SeqEncoder, TransformerPast
from ..embeddings import StaticEmbedding
from ..modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, Seq2SeqEncoder, LSTMSeq2SeqEncoder
from ..modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder
from ..core import Vocabulary
import argparse


class TransformerSeq2SeqModel(nn.Module):  # todo: follow the design of fairseq's FairseqModel
    def __init__(self, src_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
                 tgt_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
                 num_layers: int = 6, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
                 output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
                 bind_input_output_embed=False):
        super().__init__()
        self.encoder = TransformerSeq2SeqEncoder(src_embed, num_layers, d_model, n_head, dim_ff, dropout)
        self.decoder = TransformerSeq2SeqDecoder(tgt_embed, num_layers, d_model, n_head, dim_ff, dropout, output_embed,
                                                 bind_input_output_embed)
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

        self.num_layers = num_layers
    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def forward(self, words, target, seq_len):
        encoder_output, encoder_mask = self.encoder(words, seq_len)
        past = TransformerPast(encoder_output, encoder_mask, self.num_layers)
        outputs = self.decoder(target, past, return_attention=False)
    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

        return outputs
    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)

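As a rough illustration (not part of the commit), this is how the sinusoid table above can be wrapped into a frozen positional embedding, mirroring the `pos_embed` construction in `TransformerSeq2SeqModel.build_model` further down; the sizes are illustrative:

import torch
import torch.nn as nn

# assumes get_sinusoid_encoding_table from above; 1024 positions, d_model=512,
# with index 0 reserved for padding as in build_model below
table = get_sinusoid_encoding_table(1024 + 1, 512, padding_idx=0)
pos_embed = nn.Embedding.from_pretrained(table, freeze=True)  # never updated during training

positions = torch.arange(1, 10 + 1).unsqueeze(0)  # one 10-token sentence, positions start at 1
print(pos_embed(positions).shape)                 # torch.Size([1, 10, 512])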
def build_embedding(vocab, embed_dim, model_dir_or_name=None):
    """
    todo: this helper can be extended as needed; for now it only returns a StaticEmbedding.
    :param vocab: Vocabulary
    :param embed_dim:
    :param model_dir_or_name:
    :return:
    """
    assert isinstance(vocab, Vocabulary)
    embed = StaticEmbedding(vocab=vocab, embedding_dim=embed_dim, model_dir_or_name=model_dir_or_name)

    return embed

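A minimal usage sketch of `build_embedding`, assuming a toy vocabulary like the one used in the updated test at the end of this commit:

from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst("This is a test .".split())
embed = build_embedding(vocab, embed_dim=512)  # a StaticEmbedding with randomly initialised 512-d vectors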

class BaseSeq2SeqModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(BaseSeq2SeqModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        assert isinstance(self.encoder, Seq2SeqEncoder)
        assert isinstance(self.decoder, Seq2SeqDecoder)

    def forward(self, src_words, src_seq_len, tgt_prev_words):
        encoder_output, encoder_mask = self.encoder(src_words, src_seq_len)
        decoder_output = self.decoder(tgt_prev_words, encoder_output, encoder_mask)

        return {'tgt_output': decoder_output}


class LSTMSeq2SeqModel(BaseSeq2SeqModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args():
        parser = argparse.ArgumentParser()
        parser.add_argument('--dropout', type=float, default=0.3)
        parser.add_argument('--embedding_dim', type=int, default=300)
        parser.add_argument('--num_layers', type=int, default=3)
        parser.add_argument('--hidden_size', type=int, default=300)
        parser.add_argument('--bidirectional', action='store_true', default=True)
        args = parser.parse_args()

        return args

    @classmethod
    def build_model(cls, args, src_vocab, tgt_vocab):
        # build the embeddings
        src_embed = build_embedding(src_vocab, args.embedding_dim)
        if args.share_embedding:
            assert src_vocab == tgt_vocab, "share_embedding requires a joined vocab"
            tgt_embed = src_embed
        else:
            tgt_embed = build_embedding(tgt_vocab, args.embedding_dim)

        if args.bind_input_output_embed:
            output_embed = nn.Parameter(tgt_embed.embedding.weight)
        else:
            output_embed = nn.Parameter(torch.Tensor(len(tgt_vocab), args.embedding_dim), requires_grad=True)
            nn.init.normal_(output_embed, mean=0, std=args.embedding_dim ** -0.5)

        encoder = LSTMSeq2SeqEncoder(vocab=src_vocab, embed=src_embed, num_layers=args.num_layers,
                                     hidden_size=args.hidden_size, dropout=args.dropout,
                                     bidirectional=args.bidirectional)
        decoder = LSTMSeq2SeqDecoder(vocab=tgt_vocab, embed=tgt_embed, num_layers=args.num_layers,
                                     hidden_size=args.hidden_size, dropout=args.dropout, output_embed=output_embed,
                                     attention=True)

        return LSTMSeq2SeqModel(encoder, decoder)

class TransformerSeq2SeqModel(BaseSeq2SeqModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args():
        parser = argparse.ArgumentParser()
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--d_model', type=int, default=512)
        parser.add_argument('--num_layers', type=int, default=6)
        parser.add_argument('--n_head', type=int, default=8)
        parser.add_argument('--dim_ff', type=int, default=2048)
        parser.add_argument('--bind_input_output_embed', action='store_true', default=True)
        parser.add_argument('--share_embedding', action='store_true', default=True)

        args = parser.parse_args()

        return args

    @classmethod
    def build_model(cls, args, src_vocab, tgt_vocab):
        d_model = args.d_model
        args.max_positions = getattr(args, 'max_positions', 1024)  # longest sequence length that is handled

        # build the embeddings
        src_embed = build_embedding(src_vocab, d_model)
        if args.share_embedding:
            assert src_vocab == tgt_vocab, "share_embedding requires a joined vocab"
            tgt_embed = src_embed
        else:
            tgt_embed = build_embedding(tgt_vocab, d_model)

        if args.bind_input_output_embed:
            output_embed = nn.Parameter(tgt_embed.embedding.weight)
        else:
            output_embed = nn.Parameter(torch.Tensor(len(tgt_vocab), d_model), requires_grad=True)
            nn.init.normal_(output_embed, mean=0, std=d_model ** -0.5)

        pos_embed = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(args.max_positions + 1, d_model, padding_idx=0),
            freeze=True)  # position index 0 is reserved for padding here

        encoder = TransformerSeq2SeqEncoder(vocab=src_vocab, embed=src_embed, pos_embed=pos_embed,
                                            num_layers=args.num_layers, d_model=args.d_model,
                                            n_head=args.n_head, dim_ff=args.dim_ff, dropout=args.dropout)
        decoder = TransformerSeq2SeqDecoder(vocab=tgt_vocab, embed=tgt_embed, pos_embed=pos_embed,
                                            num_layers=args.num_layers, d_model=args.d_model,
                                            n_head=args.n_head, dim_ff=args.dim_ff, dropout=args.dropout,
                                            output_embed=output_embed)

        return TransformerSeq2SeqModel(encoder, decoder)
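A rough end-to-end sketch of assembling and running the new model, closely following the updated test at the bottom of this commit (the token ids are illustrative):

import torch
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst("This is a test .".split())

args = TransformerSeq2SeqModel.add_args()                        # argparse defaults
model = TransformerSeq2SeqModel.build_model(args, vocab, vocab)  # shared src/tgt vocab

src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])         # batch x src_len
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])   # batch x tgt_len (previous tokens)
src_seq_len = torch.LongTensor([3, 2])

output = model(src_words_idx, src_seq_len, tgt_words_idx)
print(output['tgt_output'].shape)  # expected to be batch x tgt_len x vocab size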
@@ -51,15 +51,17 @@ __all__ = [

    'summary',

    "BiLSTMEncoder",
    "TransformerSeq2SeqEncoder",
    "LSTMSeq2SeqEncoder",
    "Seq2SeqEncoder",

    "SequenceGenerator",
    "LSTMDecoder",
    "LSTMPast",
    "TransformerSeq2SeqDecoder",
    "LSTMSeq2SeqDecoder",
    "Seq2SeqDecoder",

    "TransformerPast",
    "Decoder",
    "LSTMPast",
    "Past"

]
@@ -9,13 +9,15 @@ __all__ = [
    "allowed_transitions",

    "SequenceGenerator",
    "LSTMDecoder",

    "LSTMPast",
    "TransformerSeq2SeqDecoder",
    "TransformerPast",
    "Decoder",
    "Past",

    "TransformerSeq2SeqDecoder",
    "LSTMSeq2SeqDecoder",
    "Seq2SeqDecoder"

]

from .crf import ConditionalRandomField
@@ -23,4 +25,5 @@ from .crf import allowed_transitions
from .mlp import MLP
from .utils import viterbi_decode
from .seq2seq_generator import SequenceGenerator
from .seq2seq_decoder import *
from .seq2seq_decoder import Seq2SeqDecoder, LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMPast, TransformerPast, \
    Past
File diff suppressed because it is too large
@@ -2,23 +2,29 @@ __all__ = [
    "SequenceGenerator"
]
import torch
from .seq2seq_decoder import Decoder
from ...models.seq2seq_model import BaseSeq2SeqModel
from ..encoder.seq2seq_encoder import Seq2SeqEncoder
from ..decoder.seq2seq_decoder import Seq2SeqDecoder
import torch.nn.functional as F
from ...core.utils import _get_model_device
from functools import partial
from ...core import Vocabulary


class SequenceGenerator:
    def __init__(self, decoder: Decoder, max_length=20, num_beams=1,
    def __init__(self, encoder: Seq2SeqEncoder = None, decoder: Seq2SeqDecoder = None,
                 max_length=20, num_beams=1,
                 do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None,
                 repetition_penalty=1, length_penalty=1.0):
        if do_sample:
            self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
            self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length,
                                         num_beams=num_beams,
                                         temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id,
                                         eos_token_id=eos_token_id, repetition_penalty=repetition_penalty,
                                         length_penalty=length_penalty)
        else:
            self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
            self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length,
                                         num_beams=num_beams,
                                         bos_token_id=bos_token_id, eos_token_id=eos_token_id,
                                         repetition_penalty=repetition_penalty,
                                         length_penalty=length_penalty)
@@ -32,30 +38,45 @@ class SequenceGenerator:
        self.eos_token_id = eos_token_id
        self.repetition_penalty = repetition_penalty
        self.length_penalty = length_penalty
        # self.vocab = tgt_vocab
        self.encoder = encoder
        self.decoder = decoder

    @torch.no_grad()
    def generate(self, tokens=None, past=None):
    def generate(self, src_tokens: torch.Tensor = None, src_seq_len: torch.Tensor = None, prev_tokens=None):
        """

        :param torch.LongTensor tokens: batch_size x length, the initial tokens
        :param past:
        :param src_tokens:
        :param src_seq_len:
        :param prev_tokens:
        :return:
        """
        # TODO check whether decoding still works directly when the length of tokens is not 1
        return self.generate_func(tokens=tokens, past=past)
        if self.encoder is not None:
            encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len)
        else:
            encoder_output = encoder_mask = None

        # re-initialize past on every call
        if encoder_output is not None:
            self.decoder.init_past(encoder_output, encoder_mask)
        else:
            self.decoder.init_past()
        return self.generate_func(src_tokens, src_seq_len, prev_tokens)
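A minimal usage sketch of the new encoder/decoder interface of `SequenceGenerator`; `encoder` and `decoder` are assumed to be an already-built `Seq2SeqEncoder`/`Seq2SeqDecoder` pair, and the special-token ids are illustrative:

import torch

generator = SequenceGenerator(encoder=encoder, decoder=decoder,
                              max_length=20, num_beams=1, do_sample=False,
                              bos_token_id=1, eos_token_id=2)

src_tokens = torch.LongTensor([[3, 4, 5], [4, 5, 0]])  # batch x src_len
src_seq_len = torch.LongTensor([3, 2])
token_ids = generator.generate(src_tokens, src_seq_len)  # batch x generated_len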
@torch.no_grad()
def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
def greedy_generate(decoder: Seq2SeqDecoder, encoder_output=None, encoder_mask=None,
                    prev_tokens=None, max_length=20, num_beams=1,
                    bos_token_id=None, eos_token_id=None,
                    repetition_penalty=1, length_penalty=1.0):
    """
    Greedily search for a sentence.

    :param Decoder decoder: the Decoder object
    :param torch.LongTensor tokens: batch_size x len, input tokens for decoding; if None, generation starts from bos_token_id
    :param Past past: should contain some outputs of the encoder.

    :param decoder:
    :param encoder_output:
    :param encoder_mask:
    :param prev_tokens: batch_size x len, input tokens for decoding; if None, generation starts from bos_token_id
    :param int max_length: maximum length of the generated sentence.
    :param int num_beams: beam size used for decoding.
    :param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
@@ -65,11 +86,18 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
    :return:
    """
    if num_beams == 1:
        token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=1, top_k=50, top_p=1,
        token_ids = _no_beam_search_generate(decoder=decoder,
                                             encoder_output=encoder_output, encoder_mask=encoder_mask,
                                             prev_tokens=prev_tokens,
                                             max_length=max_length, temperature=1,
                                             top_k=50, top_p=1,
                                             bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
                                             repetition_penalty=repetition_penalty, length_penalty=length_penalty)
    else:
        token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams,
        token_ids = _beam_search_generate(decoder=decoder,
                                          encoder_output=encoder_output, encoder_mask=encoder_mask,
                                          prev_tokens=prev_tokens, max_length=max_length,
                                          num_beams=num_beams,
                                          temperature=1, top_k=50, top_p=1,
                                          bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
                                          repetition_penalty=repetition_penalty, length_penalty=length_penalty)
@@ -78,14 +106,17 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,


@torch.no_grad()
def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, temperature=1.0, top_k=50,
def sample_generate(decoder: Seq2SeqDecoder, encoder_output=None, encoder_mask=None,
                    prev_tokens=None, max_length=20, num_beams=1,
                    temperature=1.0, top_k=50,
                    top_p=1.0, bos_token_id=None, eos_token_id=None, repetition_penalty=1.0, length_penalty=1.0):
    """
    Generate sentences by sampling.

    :param Decoder decoder: the Decoder object
    :param torch.LongTensor tokens: batch_size x len, input tokens for decoding; if None, generation starts from bos_token_id
    :param Past past: should contain some outputs of the encoder.
    :param decoder:
    :param encoder_output:
    :param encoder_mask:
    :param torch.LongTensor prev_tokens: batch_size x len, input tokens for decoding; if None, generation starts from bos_token_id
    :param int max_length: maximum length of the generated sentence.
    :param int num_beams: beam size used for decoding.
    :param float temperature: annealing temperature used when sampling
@@ -99,50 +130,55 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
    """
    # each position is generated by sampling
    if num_beams == 1:
        token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=temperature,
        token_ids = _no_beam_search_generate(decoder=decoder, encoder_output=encoder_output, encoder_mask=encoder_mask,
                                             prev_tokens=prev_tokens, max_length=max_length,
                                             temperature=temperature,
                                             top_k=top_k, top_p=top_p,
                                             bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
                                             repetition_penalty=repetition_penalty, length_penalty=length_penalty)
    else:
        token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams,
        token_ids = _beam_search_generate(decoder=decoder, encoder_output=encoder_output, encoder_mask=encoder_mask,
                                          prev_tokens=prev_tokens, max_length=max_length,
                                          num_beams=num_beams,
                                          temperature=temperature, top_k=top_k, top_p=top_p,
                                          bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
                                          repetition_penalty=repetition_penalty, length_penalty=length_penalty)
    return token_ids
def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, temperature=1.0, top_k=50,
                             top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
def _no_beam_search_generate(decoder: Seq2SeqDecoder,
                             encoder_output=None, encoder_mask: torch.Tensor = None,
                             prev_tokens: torch.Tensor = None, max_length=20,
                             temperature=1.0, top_k=50,
                             top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=False,
                             repetition_penalty=1.0, length_penalty=1.0):
    if encoder_output is not None:
        batch_size = encoder_output.size(0)
    else:
        assert prev_tokens is not None, "You have to specify either `src_tokens` or `prev_tokens`"
        batch_size = prev_tokens.size(0)
    device = _get_model_device(decoder)
    if tokens is None:

    if prev_tokens is None:
        if bos_token_id is None:
            raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
        if past is None:
            raise RuntimeError("You have to specify either `past` or `tokens`.")
        batch_size = past.num_samples()
        if batch_size is None:
            raise RuntimeError("Cannot infer the number of samples from `past`.")
        tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
    batch_size = tokens.size(0)
    if past is not None:
        assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."
            raise RuntimeError("You have to specify either `prev_tokens` or `bos_token_id`.")

        prev_tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)

    if eos_token_id is None:
        _eos_token_id = float('nan')
    else:
        _eos_token_id = eos_token_id

    for i in range(tokens.size(1)):
        scores, past = decoder.decode(tokens[:, :i + 1], past)  # batch_size x vocab_size, Past
    for i in range(prev_tokens.size(1)):  # run through the prev tokens once to initialize the state
        decoder.decode(prev_tokens[:, :i + 1], encoder_output, encoder_mask)

    token_ids = tokens.clone()
    token_ids = prev_tokens.clone()  # stores all generated tokens
    cur_len = token_ids.size(1)
    dones = token_ids.new_zeros(batch_size).eq(1)
    # tokens = tokens[:, -1:]

    while cur_len < max_length:
        scores, past = decoder.decode(tokens, past)  # batch_size x vocab_size, Past
        scores = decoder.decode(token_ids, encoder_output, encoder_mask)  # batch_size x vocab_size

        if repetition_penalty != 1.0:
            token_scores = scores.gather(dim=1, index=token_ids)
@@ -171,9 +207,9 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
        next_tokens = torch.argmax(scores, dim=-1)  # batch_size

        next_tokens = next_tokens.masked_fill(dones, 0)  # pad samples whose search has already finished
        tokens = next_tokens.unsqueeze(1)
        next_tokens = next_tokens.unsqueeze(1)

        token_ids = torch.cat([token_ids, tokens], dim=-1)  # batch_size x max_len
        token_ids = torch.cat([token_ids, next_tokens], dim=-1)  # batch_size x max_len

        end_mask = next_tokens.eq(_eos_token_id)
        dones = dones.__or__(end_mask)
@@ -189,29 +225,31 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
    return token_ids
def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, num_beams=4, temperature=1.0,
def _beam_search_generate(decoder: Seq2SeqDecoder,
                          encoder_output=None, encoder_mask: torch.Tensor = None,
                          prev_tokens: torch.Tensor = None, max_length=20, num_beams=4, temperature=1.0,
                          top_k=50,
                          top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
                          top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=False,
                          repetition_penalty=1.0, length_penalty=None) -> torch.LongTensor:
    # perform beam search
    device = _get_model_device(decoder)
    if tokens is None:
        if bos_token_id is None:
            raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
        if past is None:
            raise RuntimeError("You have to specify either `past` or `tokens`.")
        batch_size = past.num_samples()
        if batch_size is None:
            raise RuntimeError("Cannot infer the number of samples from `past`.")
        tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
    batch_size = tokens.size(0)
    if past is not None:
        assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."

    for i in range(tokens.size(1) - 1):  # if the input is longer than one token, decode it first
        scores, past = decoder.decode(tokens[:, :i + 1],
                                      past)  # (batch_size, vocab_size), Past
    scores, past = decoder.decode(tokens, past)  # the length of the whole sentence has to be passed in here
    if encoder_output is not None:
        batch_size = encoder_output.size(0)
    else:
        assert prev_tokens is not None, "You have to specify either `src_tokens` or `prev_tokens`"
        batch_size = prev_tokens.size(0)

    device = _get_model_device(decoder)

    if prev_tokens is None:
        if bos_token_id is None:
            raise RuntimeError("You have to specify either `prev_tokens` or `bos_token_id`.")

        prev_tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)

    for i in range(prev_tokens.size(1)):  # if the input is longer than one token, decode it first
        scores = decoder.decode(prev_tokens[:, :i + 1], encoder_output, encoder_mask)

    vocab_size = scores.size(1)
    assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."

@@ -225,15 +263,15 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
    # obtain (batch_size, num_beams), (batch_size, num_beams)
    next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)

    # reorder according to the indices
    indices = torch.arange(batch_size, dtype=torch.long).to(device)
    indices = indices.repeat_interleave(num_beams)
    decoder.reorder_past(indices, past)
    decoder.reorder_past(indices)
    prev_tokens = prev_tokens.index_select(dim=0, index=indices)  # batch_size * num_beams x length

    tokens = tokens.index_select(dim=0, index=indices)  # batch_size * num_beams x length
    # record the generated tokens (batch_size', cur_len)
    token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
    token_ids = torch.cat([prev_tokens, next_tokens.view(-1, 1)], dim=-1)
    dones = [False] * batch_size
    tokens = next_tokens.view(-1, 1)

    beam_scores = next_scores.view(-1)  # batch_size * num_beams

@@ -247,7 +285,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
    batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)

    while cur_len < max_length:
        scores, past = decoder.decode(tokens, past)  # batch_size * num_beams x vocab_size, Past
        scores = decoder.decode(token_ids, encoder_output, encoder_mask)  # batch_size * num_beams x vocab_size

        if repetition_penalty != 1.0:
            token_scores = scores.gather(dim=1, index=token_ids)
@@ -300,9 +338,9 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
        _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
        beam_scores = _next_scores.view(-1)

        # update the past state and reorganize token_ids
        # reorder the past/encoder state and reorganize token_ids
        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to one dimension
        decoder.reorder_past(reorder_inds, past)
        decoder.reorder_past(reorder_inds)

        flag = True
        if cur_len + 1 == max_length:
@@ -327,8 +365,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
                    hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)

        # reorganize the state of token_ids
        tokens = _next_tokens
        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
        cur_tokens = _next_tokens
        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), cur_tokens], dim=-1)

        for batch_idx in range(batch_size):
            dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
@@ -436,38 +474,3 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

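For orientation, a sketch of how `top_k_top_p_filtering` is typically combined with temperature scaling and multinomial sampling; the logits and hyper-parameters here are illustrative, not taken from the suppressed portion of this file:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 1000)                                 # batch_size x vocab_size, e.g. from decoder.decode
logits = logits / 0.7                                         # temperature < 1 sharpens the distribution
logits = top_k_top_p_filtering(logits, top_k=50, top_p=0.9)   # disallowed tokens become -inf
probs = F.softmax(logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)         # batch_size x 1 sampled token ids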
if __name__ == '__main__':
    # TODO check whether greedy_generate and sample_generate work correctly.
    from torch import nn


    class DummyDecoder(nn.Module):
        def __init__(self, num_words):
            super().__init__()
            self.num_words = num_words

        def decode(self, tokens, past):
            batch_size = tokens.size(0)
            return torch.randn(batch_size, self.num_words), past

        def reorder_past(self, indices, past):
            return past


    num_words = 10
    batch_size = 3
    decoder = DummyDecoder(num_words)

    tokens = greedy_generate(decoder=decoder, tokens=torch.zeros(batch_size, 1).long(), past=None, max_length=20,
                             num_beams=2,
                             bos_token_id=0, eos_token_id=num_words - 1,
                             repetition_penalty=1, length_penalty=1.0)
    print(tokens)

    tokens = sample_generate(decoder, tokens=torch.zeros(batch_size, 1).long(),
                             past=None, max_length=20, num_beams=2, temperature=1.0, top_k=50,
                             top_p=1.0, bos_token_id=0, eos_token_id=num_words - 1, repetition_penalty=1.0,
                             length_penalty=1.0)
    print(tokens)
@@ -31,8 +31,9 @@ __all__ = [
    "BiAttention",
    "SelfAttention",

    "BiLSTMEncoder",
    "TransformerSeq2SeqEncoder"
    "LSTMSeq2SeqEncoder",
    "TransformerSeq2SeqEncoder",
    "Seq2SeqEncoder"
]

from .attention import MultiHeadAttention, BiAttention, SelfAttention
@@ -45,4 +46,4 @@ from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU

from .seq2seq_encoder import BiLSTMEncoder, TransformerSeq2SeqEncoder
from .seq2seq_encoder import LSTMSeq2SeqEncoder, TransformerSeq2SeqEncoder, Seq2SeqEncoder
@@ -1,48 +1,238 @@
__all__ = [
    "TransformerSeq2SeqEncoder",
    "BiLSTMEncoder"
]

from torch import nn
import torch.nn as nn
import torch
from ...modules.encoder import LSTM
from ...core.utils import seq_len_to_mask
from torch.nn import TransformerEncoder
from torch.nn import LayerNorm
import torch.nn.functional as F
from typing import Union, Tuple
import numpy as np
from ...core.utils import seq_len_to_mask
import math
from ...core import Vocabulary
from ...modules import LSTM


class TransformerSeq2SeqEncoder(nn.Module):
    def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers: int = 6,
                 d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1):
        super(TransformerSeq2SeqEncoder, self).__init__()
        self.embed = embed
        self.transformer = TransformerEncoder(nn.TransformerEncoderLayer(d_model, n_head, dim_ff, dropout), num_layers)
class MultiheadAttention(nn.Module):  # todo: where should this class live?
    def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None):
        super(MultiheadAttention, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dropout = dropout
        self.head_dim = d_model // n_head
        self.layer_idx = layer_idx
        assert d_model % n_head == 0, "d_model should be divisible by n_head"
        self.scaling = self.head_dim ** -0.5

    def forward(self, words, seq_len):
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.reset_parameters()

    def forward(self, query, key, value, key_mask=None, attn_mask=None, past=None):
        """

        :param words: batch, seq_len
        :param seq_len:
        :return: output: (batch, seq_len, dim) ; encoder_mask
        :param query: batch x seq x dim
        :param key:
        :param value:
        :param key_mask: batch x seq, indicates which keys should not be attended to; note that positions where the mask is 1 ARE attended to
        :param attn_mask: seq x seq, masks out the attention map; mainly used for decoder-side self attention during training, with the lower triangle set to 1
        :param past: past information used at inference time, e.g. the encoder output and the decoder's previous key/value, to save computation.
        :return:
        """
        words = self.embed(words)  # batch, seq_len, dim
        words = words.transpose(0, 1)
        encoder_mask = seq_len_to_mask(seq_len)  # batch, seq
        words = self.transformer(words, src_key_padding_mask=~encoder_mask)  # seq_len,batch,dim
        assert key.size() == value.size()
        if past is not None:
            assert self.layer_idx is not None
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()

        return words.transpose(0, 1), encoder_mask
        q = self.q_proj(query)  # batch x seq x dim
        q *= self.scaling
        k = v = None
        prev_k = prev_v = None

        # fetch k/v from past
        if past is not None:  # we are in the inference phase
            if qkv_same:  # decoder self attention
                prev_k = past.decoder_prev_key[self.layer_idx]
                prev_v = past.decoder_prev_value[self.layer_idx]
            else:  # decoder-encoder attention: simply load the saved keys/values
                k = past.encoder_key[self.layer_idx]
                v = past.encoder_value[self.layer_idx]

        if k is None:
            k = self.k_proj(key)
            v = self.v_proj(value)

        if prev_k is not None:
            k = torch.cat((prev_k, k), dim=1)
            v = torch.cat((prev_v, v), dim=1)

        # update past
        if past is not None:
            if qkv_same:
                past.decoder_prev_key[self.layer_idx] = k
                past.decoder_prev_value[self.layer_idx] = v
            else:
                past.encoder_key[self.layer_idx] = k
                past.encoder_value[self.layer_idx] = v

        # compute attention
        batch_size, q_len, d_model = query.size()
        k_len, v_len = k.size(1), v.size(1)
        q = q.contiguous().view(batch_size, q_len, self.n_head, self.head_dim)
        k = k.contiguous().view(batch_size, k_len, self.n_head, self.head_dim)
        v = v.contiguous().view(batch_size, v_len, self.n_head, self.head_dim)

        attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k)  # bs,q_len,k_len,n_head
        if key_mask is not None:
            _key_mask = ~key_mask[:, None, :, None].bool()  # batch,1,k_len,n_head
            attn_weights = attn_weights.masked_fill(_key_mask, -float('inf'))

        if attn_mask is not None:
            _attn_mask = ~attn_mask[None, :, :, None].bool()  # 1,q_len,k_len,n_head
            attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf'))

        attn_weights = F.softmax(attn_weights, dim=2)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v)  # batch,q_len,n_head,head_dim
        output = output.reshape(batch_size, q_len, -1)
        output = self.out_proj(output)  # batch,q_len,dim

        return output, attn_weights

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)

    def set_layer_idx(self, layer_idx):
        self.layer_idx = layer_idx

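A standalone shape check for the einsum-based attention above (sizes are illustrative):

import torch

batch, q_len, k_len, n_head, head_dim = 2, 5, 7, 8, 64
q = torch.randn(batch, q_len, n_head, head_dim)
k = torch.randn(batch, k_len, n_head, head_dim)
v = torch.randn(batch, k_len, n_head, head_dim)

scores = torch.einsum('bqnh,bknh->bqkn', q, k)     # (2, 5, 7, 8): per-head scores
weights = torch.softmax(scores, dim=2)             # normalise over the key axis (dim=2)
out = torch.einsum('bqkn,bknh->bqnh', weights, v)  # (2, 5, 8, 64)
out = out.reshape(batch, q_len, -1)                # (2, 5, 512) before out_proj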
class BiLSTMEncoder(nn.Module):
    def __init__(self, embed, num_layers=3, hidden_size=400, dropout=0.3):
class TransformerSeq2SeqEncoderLayer(nn.Module):
    def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048,
                 dropout: float = 0.1):
        super(TransformerSeq2SeqEncoderLayer, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dim_ff = dim_ff
        self.dropout = dropout

        self.self_attn = MultiheadAttention(d_model, n_head, dropout)
        self.attn_layer_norm = LayerNorm(d_model)
        self.ffn_layer_norm = LayerNorm(d_model)

        self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff),
                                 nn.ReLU(),
                                 nn.Dropout(dropout),
                                 nn.Linear(self.dim_ff, self.d_model),
                                 nn.Dropout(dropout))

    def forward(self, x, encoder_mask):
        """

        :param x: batch,src_seq,dim
        :param encoder_mask: batch,src_seq
        :return:
        """
        # attention
        residual = x
        x = self.attn_layer_norm(x)
        x, _ = self.self_attn(query=x,
                              key=x,
                              value=x,
                              key_mask=encoder_mask)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x

        # ffn
        residual = x
        x = self.ffn_layer_norm(x)
        x = self.ffn(x)
        x = residual + x

        return x

class Seq2SeqEncoder(nn.Module):
    def __init__(self, vocab):
        super().__init__()
        self.vocab = vocab

    def forward(self, src_words, src_seq_len):
        raise NotImplementedError


class TransformerSeq2SeqEncoder(Seq2SeqEncoder):
    def __init__(self, vocab: Vocabulary, embed: nn.Module, pos_embed: nn.Module = None, num_layers: int = 6,
                 d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1):
        super(TransformerSeq2SeqEncoder, self).__init__(vocab)
        self.embed = embed
        self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size // 2, bidirectional=True,
        self.embed_scale = math.sqrt(d_model)
        self.pos_embed = pos_embed
        self.num_layers = num_layers
        self.d_model = d_model
        self.n_head = n_head
        self.dim_ff = dim_ff
        self.dropout = dropout

        self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout)
                                           for _ in range(num_layers)])
        self.layer_norm = LayerNorm(d_model)

    def forward(self, src_words, src_seq_len):
        """

        :param src_words: batch, src_seq_len
        :param src_seq_len: [batch]
        :return:
        """
        batch_size, max_src_len = src_words.size()
        device = src_words.device
        x = self.embed(src_words) * self.embed_scale  # batch, seq, dim
        if self.pos_embed is not None:
            position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device)
            x += self.pos_embed(position)
        x = F.dropout(x, p=self.dropout, training=self.training)

        encoder_mask = seq_len_to_mask(src_seq_len)
        encoder_mask = encoder_mask.to(device)

        for layer in self.layer_stacks:
            x = layer(x, encoder_mask)

        x = self.layer_norm(x)

        return x, encoder_mask

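A hedged sketch of a toy forward pass through the new TransformerSeq2SeqEncoder, reusing the vocabulary/embedding setup from the updated test in this commit (layer count and inputs are illustrative):

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst("Another test !".split())
embed = StaticEmbedding(vocab, embedding_dim=512)

encoder = TransformerSeq2SeqEncoder(vocab=vocab, embed=embed, pos_embed=None,
                                    num_layers=2, d_model=512, n_head=8, dim_ff=2048, dropout=0.1)
src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
src_seq_len = torch.LongTensor([3, 2])
x, encoder_mask = encoder(src_words_idx, src_seq_len)  # x: 2 x 3 x 512, encoder_mask: 2 x 3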
class LSTMSeq2SeqEncoder(Seq2SeqEncoder):
    def __init__(self, vocab: Vocabulary, embed: nn.Module, num_layers: int = 3, hidden_size: int = 400,
                 dropout: float = 0.3, bidirectional=True):
        super().__init__(vocab)
        self.embed = embed
        self.num_layers = num_layers
        self.dropout = dropout
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.lstm = LSTM(input_size=embed.embedding_dim, hidden_size=hidden_size // 2, bidirectional=bidirectional,
                         batch_first=True, dropout=dropout, num_layers=num_layers)

    def forward(self, words, seq_len):
        words = self.embed(words)
        words, hx = self.lstm(words, seq_len)
    def forward(self, src_words, src_seq_len):
        batch_size = src_words.size(0)
        device = src_words.device
        x = self.embed(src_words)
        x, (final_hidden, final_cell) = self.lstm(x, src_seq_len)
        encoder_mask = seq_len_to_mask(src_seq_len).to(device)

        return words, hx
        # x: batch,seq_len,dim; h/c: num_layers*2,batch,dim

        def concat_bidir(input):
            output = input.view(self.num_layers, 2, batch_size, -1).transpose(1, 2).contiguous()
            return output.view(self.num_layers, batch_size, -1)

        if self.bidirectional:
            final_hidden = concat_bidir(final_hidden)  # concatenate the bidirectional hidden states as input for the decoder
            final_cell = concat_bidir(final_cell)

        return (x, (final_hidden, final_cell)), encoder_mask  # returned in two parts to match the forward of Seq2SeqBaseModel
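A small worked example of the concat_bidir reshape above (sizes are illustrative):

import torch

num_layers, batch_size, hidden_size = 3, 2, 400
final_hidden = torch.randn(num_layers * 2, batch_size, hidden_size // 2)  # 2 directions, hidden_size // 2 each

merged = final_hidden.view(num_layers, 2, batch_size, -1).transpose(1, 2).contiguous()
merged = merged.view(num_layers, batch_size, -1)
print(merged.shape)  # torch.Size([3, 2, 400]): forward/backward halves concatenated per layer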
@@ -7,10 +7,12 @@ from transformer.Layers import EncoderLayer, DecoderLayer

__author__ = "Yu-Hsiang Huang"


def get_non_pad_mask(seq):
    assert seq.dim() == 2
    return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

@@ -31,6 +33,7 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):

    return torch.FloatTensor(sinusoid_table)


def get_attn_key_pad_mask(seq_k, seq_q):
    ''' For masking out the padding part of key sequence. '''

@@ -41,6 +44,7 @@ def get_attn_key_pad_mask(seq_k, seq_q):

    return padding_mask


def get_subsequent_mask(seq):
    ''' For masking out the subsequent info. '''

@@ -51,6 +55,7 @@ def get_subsequent_mask(seq):

    return subsequent_mask

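The body of get_subsequent_mask is elided by this hunk; as an illustration only, a causal mask consistent with the decoder-side attn_mask convention described earlier (lower triangle set to 1 for positions that may be attended to) can be built like this:

import torch

seq_len = 4
allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.uint8))  # 1 = may attend
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.uint8)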
class Encoder(nn.Module):
    ''' A encoder model with self attention mechanism. '''

@@ -98,6 +103,7 @@ class Encoder(nn.Module):
            return enc_output, enc_slf_attn_list
        return enc_output,


class Decoder(nn.Module):
    ''' A decoder model with self attention mechanism. '''

@@ -152,6 +158,7 @@ class Decoder(nn.Module):
            return dec_output, dec_slf_attn_list, dec_enc_attn_list
        return dec_output,


class Transformer(nn.Module):
    ''' A sequence to sequence model with attention mechanism. '''

@@ -181,8 +188,8 @@ class Transformer(nn.Module):
        nn.init.xavier_normal_(self.tgt_word_prj.weight)

        assert d_model == d_word_vec, \
            'To facilitate the residual connections, \
             the dimensions of all module outputs shall be the same.'
            'To facilitate the residual connections, \
             the dimensions of all module outputs shall be the same.'

        if tgt_emb_prj_weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
@@ -194,7 +201,7 @@ class Transformer(nn.Module):
        if emb_src_tgt_weight_sharing:
            # Share the weight matrix between source & target word embeddings
            assert n_src_vocab == n_tgt_vocab, \
                "To share word embedding table, the vocabulary size of src/tgt shall be the same."
                "To share word embedding table, the vocabulary size of src/tgt shall be the same."
            self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight

    def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):
@@ -2,8 +2,10 @@ import unittest

import torch

from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, \
    LSTMSeq2SeqDecoder
from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.core.utils import seq_len_to_mask
@@ -15,22 +17,17 @@ class TestTransformerSeq2SeqDecoder(unittest.TestCase):
        vocab.add_word_lst("Another test !".split())
        embed = StaticEmbedding(vocab, embedding_dim=512)

        encoder = TransformerSeq2SeqEncoder(embed)
        decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True)
        args = TransformerSeq2SeqModel.add_args()
        model = TransformerSeq2SeqModel.build_model(args, vocab, vocab)

        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
        src_seq_len = torch.LongTensor([3, 2])

        encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
        past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask)
        output = model(src_words_idx, src_seq_len, tgt_words_idx)
        print(output)

        decoder_outputs = decoder(tgt_words_idx, past)

        print(decoder_outputs)
        print(mask)

        self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))
        # self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))

    def test_decode(self):
        pass  # todo