Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-12-04 13:17:51 +08:00
1. Update the weight download URLs; 2. Update seq2seq to prevent the first position from predicting eos
This commit is contained in:
parent d4fda68840
commit 0a2f546b70
@@ -259,8 +259,8 @@ def _get_base_url(name):
         return url + '/'
     else:
         URLS = {
-            'embedding': "http://212.129.155.247/embedding/",
-            "dataset": "http://212.129.155.247/dataset/"
+            'embedding': "http://download.fastnlp.top/embedding/",
+            "dataset": "http://download.fastnlp.top/dataset/"
         }
         if name.lower() not in URLS:
             raise KeyError(f"{name} is not recognized.")
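Note: the effect of this hunk is just a dict lookup keyed on the lowercased name, with a KeyError for unknown names. A minimal standalone sketch of that behavior (the resolve() helper below is illustrative, not fastNLP's API):

URLS = {
    'embedding': "http://download.fastnlp.top/embedding/",
    "dataset": "http://download.fastnlp.top/dataset/",
}

def resolve(name):
    # hypothetical helper mirroring the lookup in _get_base_url
    if name.lower() not in URLS:
        raise KeyError(f"{name} is not recognized.")
    return URLS[name.lower()]

print(resolve('Embedding'))  # http://download.fastnlp.top/embedding/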
@@ -11,7 +11,8 @@ __all__ = ['SequenceGeneratorModel']


 class SequenceGeneratorModel(nn.Module):
     """
-    Wraps a Seq2SeqModel so that it can be used for generation tasks
+    Wraps a seq2seq_model so that it can be used both for training and for generation. During training, this model's
+    forward function is called; during generation, its predict function is called.

     """
@@ -46,7 +47,7 @@ class SequenceGeneratorModel(nn.Module):

     def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
         """
-        Pass-through call to seq2seq_model's forward
+        Pass-through call to seq2seq_model's forward.

         :param torch.LongTensor src_tokens: bsz x max_len
         :param torch.LongTensor tgt_tokens: bsz x max_len
@@ -58,7 +59,7 @@ class SequenceGeneratorModel(nn.Module):

     def predict(self, src_tokens, src_seq_len=None):
         """
-        Given the source content, output the generated content
+        Given the source content, output the generated content.

         :param torch.LongTensor src_tokens: bsz x max_len
         :param torch.LongTensor src_seq_len: bsz
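Note: the docstrings above describe a train/generate split: forward() runs during training, predict() during generation. A rough sketch of that split, assuming the wrapped model exposes prepare_state() and the generator exposes generate() (both names taken from the surrounding docs, everything else assumed):

import torch
import torch.nn as nn

class TinyGeneratorWrapper(nn.Module):
    # illustrative only; not fastNLP's SequenceGeneratorModel
    def __init__(self, seq2seq_model, generator):
        super().__init__()
        self.seq2seq_model = seq2seq_model  # training path
        self.generator = generator          # generation path

    def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
        # training: pass straight through to the wrapped seq2seq model
        return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)

    @torch.no_grad()
    def predict(self, src_tokens, src_seq_len=None):
        # generation: encode once, then decode autoregressively
        state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len)
        return self.generator.generate(state)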
@@ -18,10 +18,16 @@ __all__ = ['Seq2SeqModel', 'TransformerSeq2SeqModel', 'LSTMSeq2SeqModel']
 class Seq2SeqModel(nn.Module):
     def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder):
         """
-        A Seq2Seq model that can be trained in a Trainer. Normally, after subclassing, only the classmethod build_model needs to be implemented.
+        A Seq2Seq model that can be trained in a Trainer. Normally, after subclassing, only the classmethod build_model needs to be implemented. To use this model
+        for generation, pass it into :class:`~fastNLP.models.SequenceGeneratorModel`. In this model, forward() feeds the encoded
+        results into the decoder and returns the decoder's output.

-        :param encoder: Encoder
-        :param decoder: Decoder
+        :param encoder: a Seq2SeqEncoder object. It must implement a forward() function taking two arguments: the first is bsz x max_len source tokens, the second is
+            the bsz source lengths. It must return two tensors, encoder_outputs: bsz x max_len x hidden_size and encoder_mask: bsz x max_len,
+            where positions equal to 1 should be attended to. If the encoder's inputs or outputs differ, override this model's prepare_state() or forward() function.
+        :param decoder: a Seq2SeqDecoder object. It must implement an init_state() function taking two arguments: the first, of shape bsz x max_len x hidden_size, is
+            the encoder output; the second, of shape bsz x max_len, is the encoder output mask, where 0 marks padding. If the decoder needs more inputs, override this
+            model's prepare_state() or forward() function.
         """
         super().__init__()
         self.encoder = encoder
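Note: the new :param encoder: text spells out an exact contract. A toy encoder satisfying it might look like this (illustrative internals; vocab and hidden sizes are assumptions):

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    # honors the contract above: (src_tokens, src_seq_len) ->
    # (encoder_outputs: bsz x max_len x hidden_size,
    #  encoder_mask:    bsz x max_len, 1 where attention is allowed)
    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, src_tokens, src_seq_len):
        encoder_outputs = self.embed(src_tokens)
        max_len = src_tokens.size(1)
        encoder_mask = torch.arange(max_len)[None, :] < src_seq_len[:, None]
        return encoder_outputs, encoder_mask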
@@ -16,7 +16,7 @@ __all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder']

 class Seq2SeqDecoder(nn.Module):
     """
-    Base class for Sequence-to-Sequence decoders. The forward function must be implemented; the remaining functions are optional. Every Seq2SeqDecoder should have a corresponding State object
+    Base class for Sequence-to-Sequence decoders. The forward and decode functions must be implemented; the remaining functions are optional. Every Seq2SeqDecoder should have a corresponding State object
     used to carry the encoder outputs it needs and any history the decoder must track (e.g. an LSTM's hidden state).

     """
@@ -61,7 +61,7 @@ class Seq2SeqDecoder(nn.Module):
         """
         Continue generation from the contents of state and tokens.

-        :param torch.LongTensor tokens: bsz x max_len, the token output from the previous time step.
+        :param torch.LongTensor tokens: bsz x max_len, all token outputs up to and including the previous time step.
         :param State state: records the encoder output and the decoder's past state
         :return: torch.FloatTensor: bsz x vocab_size, the distribution for the next time step
         """
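Note: per the corrected wording, decode() receives every token generated so far, not just the last step. A toy decoder showing the expected shapes (the pooling logic is a placeholder, not fastNLP's implementation):

import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    # illustrative decode(): all tokens so far + state -> next-step scores
    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)

    def decode(self, tokens, state):
        # tokens: bsz x cur_len, everything generated up to the previous step
        hidden = self.embed(tokens).mean(dim=1)  # crude pooling, toy only
        return self.out(hidden)                  # bsz x vocab_size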
@@ -12,9 +12,11 @@ import torch.nn.functional as F
 from ...core.utils import _get_model_device
+from functools import partial


 class SequenceGenerator:
     """
-    Given a Seq2SeqDecoder, decode sentences
+    Given a Seq2SeqDecoder, decode sentences. The decoder object passed in must have a decode() function whose first argument is everything decoded so far,
+    and whose second argument is the state. SequenceGenerator performs no operations on the state.

     """
     def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1,
@@ -65,7 +67,8 @@ class SequenceGenerator:
         """

         :param State state: the State built from the encoder results, used together with the matching Decoder
-        :param torch.LongTensor,None tokens: batch_size x length, the starting tokens
+        :param torch.LongTensor,None tokens: batch_size x length, the starting tokens. If None, bos_token is added by default as the opening token
+            for generation.
         :return: bsz x max_length, the generated token sequences. If eos_token_id is not None, every sequence always ends with eos_token_id
         """
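Note: the clarified tokens=None behavior amounts to seeding generation with a column of bos tokens; a small sketch of that default (the bos_token_id and batch_size values are assumed):

import torch

bos_token_id, batch_size = 1, 4  # assumed values
# when tokens is None, start every sequence from bos
tokens = torch.full((batch_size, 1), bos_token_id, dtype=torch.long)
print(tokens.shape)  # torch.Size([4, 1]) -- batch_size x length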
@@ -168,6 +171,8 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
         _eos_token_id = eos_token_id

     scores = decoder.decode(tokens=tokens, state=state)  # mainly to update the state
+    if _eos_token_id != -1:  # prevent the first position from being eos
+        scores[:, _eos_token_id] = -1e12
     next_tokens = scores.argmax(dim=-1, keepdim=True)
     token_ids = torch.cat([tokens, next_tokens], dim=1)
     cur_len = token_ids.size(1)
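Note: this is the fix named in the commit message. Masking the eos column before the first argmax guarantees generation cannot terminate with an empty output. A toy reproduction of just the added guard (random scores, assumed shapes):

import torch

bsz, vocab_size, eos_token_id = 2, 6, 3
scores = torch.randn(bsz, vocab_size)
scores[:, eos_token_id] = -1e12                    # eos can no longer win
next_tokens = scores.argmax(dim=-1, keepdim=True)  # bsz x 1
assert not (next_tokens == eos_token_id).any()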
@@ -261,6 +266,8 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         _eos_token_id = eos_token_id

     scores = decoder.decode(tokens=tokens, state=state)  # the entire sequence so far is passed in here
+    if _eos_token_id != -1:  # prevent the first position from being eos
+        scores[:, _eos_token_id] = -1e12
     vocab_size = scores.size(1)
     assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."

@@ -322,7 +329,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             max_len_eos_mask = max_lengths.eq(cur_len+1)
             eos_scores = scores[:, _eos_token_id]
             # if the maximum length has been reached, boost the eos score
-            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)
+            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+float('inf'), eos_scores)

         if do_sample:
             if temperature > 0 and temperature != 1:
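Note: this is the mirror-image trick of the first-step guard: where a beam has hit its length limit, an infinite bonus forces eos to win. Swapping the finite 1e12 bonus for float('inf') keeps the forcing robust even against very large logits. Toy tensors below (shapes and values assumed):

import torch

eos_token_id = 3
scores = torch.randn(4, 6)  # (bsz * num_beams) x vocab_size
max_len_eos_mask = torch.tensor([True, False, True, False])
eos_scores = scores[:, eos_token_id]
scores[:, eos_token_id] = torch.where(
    max_len_eos_mask, eos_scores + float('inf'), eos_scores)
assert scores[0].argmax() == eos_token_id
assert scores[2].argmax() == eos_token_id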
@@ -356,9 +363,9 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_

         # next, assemble the results for the next batch
         # decide which candidates to keep
-        next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
-        next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
-        from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)
+        # next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
+        # next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
+        # from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)

         not_eos_mask = next_tokens.ne(_eos_token_id)  # 1 where the token is not eos
         keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)  # 1 where the candidate should be kept
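Note: the surviving keep_mask line relies on a cumulative-sum trick: scanning the candidates in order, positions stay kept until num_beams non-eos tokens have been seen. A toy example (candidate values assumed):

import torch

num_beams, eos_token_id = 2, 3
next_tokens = torch.tensor([[5, 3, 7, 2]])  # candidates, best first
not_eos_mask = next_tokens.ne(eos_token_id)            # [[1, 0, 1, 1]]
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)   # cumsum: [[1, 1, 2, 3]]
print(keep_mask)  # tensor([[ True,  True,  True, False]])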
@@ -413,7 +420,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             break

     # select the best hypotheses
-    tgt_len = token_ids.new(batch_size)
+    tgt_len = token_ids.new_zeros(batch_size)
     best = []

     for i, hypotheses in enumerate(hypos):
@@ -425,7 +432,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         best.append(best_hyp)

     # generate target batch
-    decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
+    decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id)
     for i, hypo in enumerate(best):
         decoded[i, :tgt_len[i]] = hypo

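Note: new() returns uninitialized memory while new_zeros() is deterministic; with the trailing fill_() the result is the same here, but zeros are safer if the fill is ever dropped. A toy version of this right-padding assembly (the hypotheses are assumed values):

import torch

pad_token_id = 0
best = [torch.tensor([5, 7, 3]), torch.tensor([8, 3])]  # assumed hypotheses
tgt_len = torch.tensor([h.size(0) for h in best])
decoded = torch.zeros(len(best), int(tgt_len.max()),
                      dtype=torch.long).fill_(pad_token_id)
for i, hypo in enumerate(best):
    decoded[i, :tgt_len[i]] = hypo
print(decoded)  # tensor([[5, 7, 3], [8, 3, 0]])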