1.更新权重下载url; 2.更新seq2seq,方式第一个位置预测eos

This commit is contained in:
yh_cc 2020-12-27 16:01:51 +08:00
parent d4fda68840
commit 0a2f546b70
5 changed files with 32 additions and 18 deletions

View File

@ -259,8 +259,8 @@ def _get_base_url(name):
return url + '/'
else:
URLS = {
'embedding': "http://212.129.155.247/embedding/",
"dataset": "http://212.129.155.247/dataset/"
'embedding': "http://download.fastnlp.top/embedding/",
"dataset": "http://download.fastnlp.top/dataset/"
}
if name.lower() not in URLS:
raise KeyError(f"{name} is not recognized.")

View File

@ -11,7 +11,8 @@ __all__ = ['SequenceGeneratorModel']
class SequenceGeneratorModel(nn.Module):
"""
用于封装Seq2SeqModel使其可以做生成任务
通过使用本模型封装seq2seq_model使得其既可以用于训练也可以用于生成训练的时候本模型的forward函数会被调用生成的时候本模型的predict
函数会被调用
"""
@ -46,7 +47,7 @@ class SequenceGeneratorModel(nn.Module):
def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
"""
透传调用seq2seq_model的forward
透传调用seq2seq_model的forward
:param torch.LongTensor src_tokens: bsz x max_len
:param torch.LongTensor tgt_tokens: bsz x max_len'
@ -58,7 +59,7 @@ class SequenceGeneratorModel(nn.Module):
def predict(self, src_tokens, src_seq_len=None):
"""
给定source的内容输出generate的内容
给定source的内容输出generate的内容
:param torch.LongTensor src_tokens: bsz x max_len
:param torch.LongTensor src_seq_len: bsz

View File

@ -18,10 +18,16 @@ __all__ = ['Seq2SeqModel', 'TransformerSeq2SeqModel', 'LSTMSeq2SeqModel']
class Seq2SeqModel(nn.Module):
def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder):
"""
可以用于在Trainer中训练的Seq2Seq模型正常情况下继承了该函数之后只需要实现classmethod build_model即可
可以用于在Trainer中训练的Seq2Seq模型正常情况下继承了该函数之后只需要实现classmethod build_model即可如果需要使用该模型
进行生成需要把该模型输入到 :class:`~fastNLP.models.SequenceGeneratorModel` 在本模型中forward()会把encoder后的
结果传入到decoder中并将decoder的输出output出来
:param encoder: Encoder
:param decoder: Decoder
:param encoder: Seq2SeqEncoder 对象需要实现对应的forward()函数接受两个参数第一个为bsz x max_len的source tokens, 第二个为
bsz的source的长度需要返回两个tensor: encoder_outputs: bsz x max_len x hidden_size, encoder_mask: bsz x max_len
为1的地方需要被attend如果encoder的输出或者输入有变化可以重载本模型的prepare_state()函数或者forward()函数
:param decoder: Seq2SeqDecoder 对象需要实现init_state()函数输出为两个参数第一个为bsz x max_len x hidden_size是
encoder的输出; 第二个为bsz x max_len为encoder输出的mask为0的地方为pad若decoder需要更多输入请重载当前模型的
prepare_state()或forward()函数
"""
super().__init__()
self.encoder = encoder

View File

@ -16,7 +16,7 @@ __all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder']
class Seq2SeqDecoder(nn.Module):
"""
Sequence-to-Sequence Decoder的基类一定需要实现forward函数剩下的函数根据需要实现每个Seq2SeqDecoder都应该有相应的State对象
Sequence-to-Sequence Decoder的基类一定需要实现forwarddecode函数剩下的函数根据需要实现每个Seq2SeqDecoder都应该有相应的State对象
用来承载该Decoder所需要的Encoder输出Decoder需要记录的历史信息(例如LSTM的hidden信息)
"""
@ -61,7 +61,7 @@ class Seq2SeqDecoder(nn.Module):
"""
根据states中的内容以及tokens中的内容进行之后的生成
:param torch.LongTensor tokens: bsz x max_len, 上一个时刻的token输出
:param torch.LongTensor tokens: bsz x max_len, 截止到上一个时刻所有的token输出
:param State state: 记录了encoder输出与decoder过去状态
:return: torch.FloatTensor: bsz x vocab_size, 输出的是下一个时刻的分布
"""

View File

@ -12,9 +12,11 @@ import torch.nn.functional as F
from ...core.utils import _get_model_device
from functools import partial
class SequenceGenerator:
"""
给定一个Seq2SeqDecoderdecode出句子
给定一个Seq2SeqDecoderdecode出句子输入的decoder对象需要有decode()函数, 接受的第一个参数为decode的到目前位置的所有输出
第二个参数为stateSequenceGenerator不会对state进行任何操作
"""
def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1,
@ -65,7 +67,8 @@ class SequenceGenerator:
"""
:param State state: encoder结果的State, 是与Decoder配套是用的
:param torch.LongTensor,None tokens: batch_size x length, 开始的token
:param torch.LongTensor,None tokens: batch_size x length, 开始的token如果为None则默认添加bos_token作为开头的token
进行生成
:return: bsz x max_length' 生成的token序列。如果eos_token_id不为None, 每个sequence的结尾一定是eos_token_id
"""
@ -168,6 +171,8 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
_eos_token_id = eos_token_id
scores = decoder.decode(tokens=tokens, state=state) # 主要是为了update state
if _eos_token_id!=-1: # 防止第一个位置为结束
scores[:, _eos_token_id] = -1e12
next_tokens = scores.argmax(dim=-1, keepdim=True)
token_ids = torch.cat([tokens, next_tokens], dim=1)
cur_len = token_ids.size(1)
@ -261,6 +266,8 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
_eos_token_id = eos_token_id
scores = decoder.decode(tokens=tokens, state=state) # 这里要传入的是整个句子的长度
if _eos_token_id!=-1: # 防止第一个位置为结束
scores[:, _eos_token_id] = -1e12
vocab_size = scores.size(1)
assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."
@ -322,7 +329,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
max_len_eos_mask = max_lengths.eq(cur_len+1)
eos_scores = scores[:, _eos_token_id]
# 如果已经达到最大长度就把eos的分数加大
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+float('inf'), eos_scores)
if do_sample:
if temperature > 0 and temperature != 1:
@ -356,9 +363,9 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
# 接下来需要组装下一个batch的结果。
# 需要选定哪些留下来
next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)
# next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
# next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
# from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)
not_eos_mask = next_tokens.ne(_eos_token_id) # 为1的地方不是eos
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) # 为1的地方需要保留
@ -413,7 +420,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
break
# select the best hypotheses
tgt_len = token_ids.new(batch_size)
tgt_len = token_ids.new_zeros(batch_size)
best = []
for i, hypotheses in enumerate(hypos):
@ -425,7 +432,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
best.append(best_hyp)
# generate target batch
decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id)
for i, hypo in enumerate(best):
decoded[i, :tgt_len[i]] = hypo