PyTorch 1.2 adds the new BoolTensor type, and masked_fill now expects bool masks instead of ByteTensor indices; update fastNLP for compatibility

yh_cc 2019-08-22 15:51:44 +08:00
parent e2232ac39f
commit f18ab642d7
5 changed files with 9 additions and 9 deletions
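Every hunk below makes the same substitution: masks produced with .byte() (uint8) are replaced by bool masks produced with .eq(...), because PyTorch 1.2 deprecates uint8 masks as masked_fill indices. A minimal standalone sketch of the dtype difference (not part of this commit; words and unk_index are illustrative):

import torch

words = torch.tensor([[3, 0, 5],
                      [0, 7, 0]])
unk_index = 1

byte_mask = words.eq(0).byte()   # uint8 mask: deprecated as a masked_fill index since PyTorch 1.2
bool_mask = words.eq(0)          # bool mask: what masked_fill expects from 1.2 on

print(byte_mask.dtype, bool_mask.dtype)        # torch.uint8 torch.bool
print(words.masked_fill(bool_mask, unk_index))  # zeros replaced by unk_index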


@@ -115,7 +115,7 @@ class BertEmbedding(ContextualEmbedding):
 if self._word_sep_index:  # must not drop sep
     sep_mask = words.eq(self._word_sep_index)
 mask = torch.ones_like(words).float() * self.word_dropout
-mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are 1
+mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are 1
 words = words.masked_fill(mask, self._word_unk_index)
 if self._word_sep_index:
     words.masked_fill_(sep_mask, self._word_sep_index)
@@ -252,7 +252,7 @@ class BertWordPieceEncoder(nn.Module):
 if self._word_sep_index:  # must not drop sep
     sep_mask = words.eq(self._wordpiece_unk_index)
 mask = torch.ones_like(words).float() * self.word_dropout
-mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are 1
+mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are 1
 words = words.masked_fill(mask, self._word_unk_index)
 if self._word_sep_index:
     words.masked_fill_(sep_mask, self._wordpiece_unk_index)


@@ -63,7 +63,7 @@ class Embedding(nn.Module):
 """
 if self.word_dropout>0 and self.training:
     mask = torch.ones_like(words).float() * self.word_dropout
-    mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are 1
+    mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are 1
     words = words.masked_fill(mask, self.unk_index)
 words = self.embed(words)
 return self.dropout(words)
@@ -135,7 +135,7 @@ class TokenEmbedding(nn.Module):
 """
 if self.word_dropout > 0 and self.training:
     mask = torch.ones_like(words).float() * self.word_dropout
-    mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are 1
+    mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are 1
     words = words.masked_fill(mask, self._word_unk_index)
 return words


@@ -150,7 +150,7 @@ class GraphParser(BaseModel):
 """
 _, seq_len, _ = arc_matrix.shape
 matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
-flip_mask = (mask == 0).byte()
+flip_mask = mask.eq(0)
 matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
 _, heads = torch.max(matrix, dim=2)
 if mask is not None:
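A minimal standalone sketch of the greedy arc decoding touched here, assuming arc_matrix holds scores of (dependent, head) pairs with shape [batch, seq_len, seq_len] and mask marks real tokens with 1; values are illustrative:

import torch

batch, seq_len = 1, 4
arc_matrix = torch.randn(batch, seq_len, seq_len)       # score of (dependent, head) pairs
mask = torch.tensor([[1, 1, 1, 0]])                     # last position is padding

matrix = arc_matrix + torch.diag(arc_matrix.new_full((seq_len,), float('-inf')))  # forbid self-loops
flip_mask = mask.eq(0)                                  # bool mask of padded positions
matrix.masked_fill_(flip_mask.unsqueeze(1), float('-inf'))  # padded positions can never be heads
_, heads = torch.max(matrix, dim=2)                     # greedy head choice per dependent
print(heads)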


@@ -210,7 +210,7 @@ class ConditionalRandomField(nn.Module):
 trans_score = self.trans_m.view(1, n_tags, n_tags)
 tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
 alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
-    alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)
+    alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0)
 if self.include_start_end_trans:
     alpha = alpha + self.end_scores.view(1, -1)
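The alpha update above uses a pair of complementary masked_fill calls so that padded time steps carry the previous alpha forward unchanged. A hedged, isolated sketch of that trick (shapes and values are illustrative):

import torch

batch_size, n_tags = 2, 3
alpha = torch.randn(batch_size, n_tags)       # alpha from the previous time step
new_alpha = torch.randn(batch_size, n_tags)   # candidate alpha computed at this step
mask_i = torch.tensor([1, 0])                 # step is real for sample 0, padding for sample 1
flip_mask_i = mask_i.eq(0)

alpha = new_alpha.masked_fill(flip_mask_i.view(batch_size, 1), 0) + \
        alpha.masked_fill(mask_i.eq(1).view(batch_size, 1), 0)
# row 0 now equals new_alpha[0], row 1 still equals the old alpha[1]
print(alpha)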
@@ -230,7 +230,7 @@ class ConditionalRandomField(nn.Module):
 seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
 # trans_socre [L-1, B]
-mask = mask.byte()
+mask = mask.eq(1)
 flip_mask = mask.eq(0)
 trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
 # emit_score [L, B]
@@ -278,7 +278,7 @@ class ConditionalRandomField(nn.Module):
 """
 batch_size, seq_len, n_tags = logits.size()
 logits = logits.transpose(0, 1).data  # L, B, H
-mask = mask.transpose(0, 1).data.byte()  # L, B
+mask = mask.transpose(0, 1).data.eq(1)  # L, B
 # dp
 vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
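A hedged sketch of the masked transition-score gathering from the @@ -230,7 hunk above, assuming time-major tags and mask of shape [L, B]; values are illustrative:

import torch

seq_len, batch_size, n_tags = 4, 2, 3
trans_m = torch.randn(n_tags, n_tags)                        # transition scores
tags = torch.randint(0, n_tags, (seq_len, batch_size))       # gold tag sequence, [L, B]
mask = torch.tensor([[1, 1], [1, 1], [1, 0], [0, 0]]).eq(1)  # bool mask, [L, B]
flip_mask = mask.eq(0)

# score of each gold transition tags[t] -> tags[t+1], zeroed at padded steps
trans_score = trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
print(trans_score.shape)   # torch.Size([3, 2]) == [L-1, B]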


@@ -27,7 +27,7 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
     "compatible."
 logits = logits.transpose(0, 1).data  # L, B, H
 if mask is not None:
-    mask = mask.transpose(0, 1).data.byte()  # L, B
+    mask = mask.transpose(0, 1).data.eq(1)  # L, B
 else:
     mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)