Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-12-02 20:27:35 +08:00
Fix a bug in BertEmbedding
parent b0c50f7299
commit 88dafd7f9a
@@ -306,11 +306,8 @@ class _WordBertModel(nn.Module):
                 raise RuntimeError("After split words into word pieces, the lengths of word pieces are longer than the "
                                    f"maximum allowed sequence length:{self._max_position_embeddings} of bert.")
 
-
         # +2 because [CLS] and [SEP] need to be added
         word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index)
-        word_pieces[:, 0].fill_(self._cls_index)
-        batch_indexes = torch.arange(batch_size).to(words)
         attn_masks = torch.zeros_like(word_pieces)
         # 1. get the word_piece ids of the words, and their corresponding spans
         word_indexes = words.tolist()
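For readers skimming the hunk above: word_pieces is preallocated with two extra columns so that [CLS] and [SEP] fit around the actual word pieces. A minimal, self-contained sketch of that allocation step, using made-up toy sizes and a hypothetical pad id rather than fastNLP's real values:

import torch

# toy values; the real batch_size, max_word_piece_length and pad id
# come from the tokenizer and the incoming batch
batch_size, max_word_piece_length, pad_index = 2, 5, 0

words = torch.tensor([[3, 4, 5], [6, 7, 0]])  # toy word-id batch
# +2 columns reserve slot 0 for [CLS] and one trailing slot for [SEP]
word_pieces = words.new_full((batch_size, max_word_piece_length + 2),
                             fill_value=pad_index)
attn_masks = torch.zeros_like(word_pieces)  # same shape/dtype/device, all zeros
print(word_pieces.shape)  # torch.Size([2, 7])

Note that new_full and zeros_like inherit dtype and device from words, so these buffers land on the same CPU/GPU as the input batch.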
@@ -319,8 +316,11 @@ class _WordBertModel(nn.Module):
             if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2:
                 word_pieces_i = word_pieces_i[:self._max_position_embeddings-2]
             word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i)
-            word_pieces[i, len(word_pieces_i)+1] = self._sep_index  # fill in the sep
             attn_masks[i, :word_pieces_lengths[i]+2].fill_(1)
+        # add [cls] and [sep]
+        word_pieces[:, 0].fill_(self._cls_index)
+        batch_indexes = torch.arange(batch_size).to(words)
+        word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index
         # 2. get the hidden states and pool them according to word_pieces
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
         bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
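Taken together, the two hunks move the [CLS]/[SEP] writes out of the per-sample loop, so that the [SEP] position and the attention mask are now derived from the same word_pieces_lengths tensor instead of the possibly truncated len(word_pieces_i). A minimal sketch of the fixed placement on a toy batch; the ids and lengths below are hypothetical, not fastNLP's real vocabulary values:

import torch

# hypothetical toy ids; fastNLP takes these from the real BERT vocab
cls_index, sep_index, pad_index = 101, 102, 0
batch_size = 2

word_pieces = torch.full((batch_size, 7), pad_index, dtype=torch.long)
word_pieces_lengths = torch.tensor([5, 3])   # pieces per sample, excl. [CLS]/[SEP]
word_pieces[0, 1:6] = torch.arange(10, 15)   # stand-ins for real word-piece ids
word_pieces[1, 1:4] = torch.arange(20, 23)

word_pieces[:, 0].fill_(cls_index)           # [CLS] always sits at position 0
batch_indexes = torch.arange(batch_size)
# advanced indexing drops [SEP] right after each sample's last word piece,
# the same position the attention mask is built from below
word_pieces[batch_indexes, word_pieces_lengths + 1] = sep_index

attn_masks = torch.zeros_like(word_pieces)
for i in range(batch_size):
    attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)  # +2 covers [CLS]+[SEP]

Because both the [SEP] position and the mask now come from word_pieces_lengths, the two can no longer drift apart when auto_truncate shortens a sample's pieces, which appears to be the point of this fix.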