mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 12:17:35 +08:00
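修复bert embedding的bug (Fix a BERT embedding bug: for the 'first' and 'last' pooling methods, the expanded batch index is now stored in `_batch_indexes` instead of overwriting `batch_indexes`)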
This commit is contained in:
parent
60a535db08
commit
14d048f340
@ -420,11 +420,11 @@ class _WordBertModel(nn.Module):
|
||||
if self.pool_method == 'first':
|
||||
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
|
||||
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
|
||||
batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
|
||||
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
|
||||
elif self.pool_method == 'last':
|
||||
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1
|
||||
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
|
||||
batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
|
||||
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
|
||||
|
||||
for l_index, l in enumerate(self.layers):
|
||||
output_layer = bert_outputs[l]
|
||||
@ -437,12 +437,12 @@ class _WordBertModel(nn.Module):
|
||||
# Collapse the word-piece representations into word-level representations
|
||||
truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size
|
||||
if self.pool_method == 'first':
|
||||
tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length]
|
||||
tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
|
||||
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
|
||||
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
|
||||
|
||||
elif self.pool_method == 'last':
|
||||
tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length]
|
||||
tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
|
||||
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
|
||||
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
|
||||
elif self.pool_method == 'max':
|
||||
|
Loading…
Reference in New Issue
Block a user