Mirror of https://gitee.com/fastnlp/fastNLP.git
PyTorch 1.2 adds the BoolTensor type and deprecates ByteTensor masks for masked_fill; update fastNLP to stay compatible.
This commit is contained in:
parent e2232ac39f
commit f18ab642d7
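All of the hunks below follow the same pattern: masks built with .byte() become bool masks built with .eq(1) or .eq(0), since PyTorch 1.2 emits a deprecation warning when masked_fill receives a torch.uint8 mask. A minimal, self-contained sketch of the word-dropout case (toy values and names, not code from the commit):

import torch

# Toy illustration (not from this commit) of the pattern being changed:
# under PyTorch >= 1.2, masked_fill expects a torch.bool mask, so the sampled
# dropout mask is compared with eq(1) instead of being cast with byte().
words = torch.randint(5, 100, (2, 6))        # a toy batch of word ids
word_dropout, unk_index = 0.3, 1             # hypothetical values

mask = torch.ones_like(words).float() * word_dropout
# old: mask = torch.bernoulli(mask).byte()   -> uint8, now deprecated as a mask
mask = torch.bernoulli(mask).eq(1)           # -> torch.bool
words = words.masked_fill(mask, unk_index)   # dropped positions become <unk>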
@@ -115,7 +115,7 @@ class BertEmbedding(ContextualEmbedding):
if self._word_sep_index:  # sep must not be dropped
    sep_mask = words.eq(self._word_sep_index)
mask = torch.ones_like(words).float() * self.word_dropout
-mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are set to 1
+mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are set to 1
words = words.masked_fill(mask, self._word_unk_index)
if self._word_sep_index:
    words.masked_fill_(sep_mask, self._word_sep_index)
@@ -252,7 +252,7 @@ class BertWordPieceEncoder(nn.Module):
if self._word_sep_index:  # sep must not be dropped
    sep_mask = words.eq(self._wordpiece_unk_index)
mask = torch.ones_like(words).float() * self.word_dropout
-mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are set to 1
+mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are set to 1
words = words.masked_fill(mask, self._word_unk_index)
if self._word_sep_index:
    words.masked_fill_(sep_mask, self._wordpiece_unk_index)
@@ -63,7 +63,7 @@ class Embedding(nn.Module):
"""
if self.word_dropout>0 and self.training:
    mask = torch.ones_like(words).float() * self.word_dropout
-    mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are set to 1
+    mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are set to 1
    words = words.masked_fill(mask, self.unk_index)
words = self.embed(words)
return self.dropout(words)
@@ -135,7 +135,7 @@ class TokenEmbedding(nn.Module):
"""
if self.word_dropout > 0 and self.training:
    mask = torch.ones_like(words).float() * self.word_dropout
-    mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are set to 1
+    mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are set to 1
    words = words.masked_fill(mask, self._word_unk_index)
return words
@@ -150,7 +150,7 @@ class GraphParser(BaseModel):
"""
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
-flip_mask = (mask == 0).byte()
+flip_mask = mask.eq(0)
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
_, heads = torch.max(matrix, dim=2)
if mask is not None:
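For the parser hunk above, a hedged sketch of the greedy head selection it belongs to, with toy inputs and shapes assumed from the surrounding code; the key point is that the padding mask built with eq(0) is already torch.bool, so no byte() cast is needed:

import numpy as np
import torch

# Toy stand-ins for the parser's inputs; shapes are assumed from the hunk above.
arc_matrix = torch.randn(2, 5, 5)             # arc scores: (batch, seq_len, seq_len)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])        # 1 = real token, 0 = padding

_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))  # forbid self-loops
flip_mask = mask.eq(0)                                 # bool padding mask
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)   # padded positions can never be heads
_, heads = torch.max(matrix, dim=2)                    # greedy head choice per token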
@@ -210,7 +210,7 @@ class ConditionalRandomField(nn.Module):
trans_score = self.trans_m.view(1, n_tags, n_tags)
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
-    alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)
+    alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0)

if self.include_start_end_trans:
    alpha = alpha + self.end_scores.view(1, -1)
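One way to read the updated line: flip_mask is mask.eq(0), so at a padded timestep the fresh logsumexp term is zeroed and the previous alpha is carried forward unchanged, while at a valid timestep the old alpha is zeroed and the new score takes over. A toy sketch under those assumptions (shapes and values are illustrative, not from the repo):

import torch

batch_size, n_tags = 2, 3
alpha = torch.randn(batch_size, n_tags)        # forward scores from step i-1
tmp = torch.randn(batch_size, n_tags, n_tags)  # alpha + emit + trans at step i
mask_i = torch.tensor([1, 0])                  # sample 0 is valid at step i, sample 1 is padding
flip_mask_i = mask_i.eq(0)

alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask_i.view(batch_size, 1), 0) + \
        alpha.masked_fill(mask_i.eq(1).view(batch_size, 1), 0)
# row 0 now holds the fresh logsumexp; row 1 kept its previous alpha unchanged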
@@ -230,7 +230,7 @@ class ConditionalRandomField(nn.Module):
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

# trans_score [L-1, B]
-mask = mask.byte()
+mask = mask.eq(1)
flip_mask = mask.eq(0)
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
# emit_score [L, B]
@@ -278,7 +278,7 @@ class ConditionalRandomField(nn.Module):
"""
batch_size, seq_len, n_tags = logits.size()
logits = logits.transpose(0, 1).data  # L, B, H
-mask = mask.transpose(0, 1).data.byte()  # L, B
+mask = mask.transpose(0, 1).data.eq(1)  # L, B

# dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
@@ -27,7 +27,7 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
"compatible."
logits = logits.transpose(0, 1).data  # L, B, H
if mask is not None:
-    mask = mask.transpose(0, 1).data.byte()  # L, B
+    mask = mask.transpose(0, 1).data.eq(1)  # L, B
else:
    mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)
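Note that the else branch above still creates the default mask as torch.uint8. A possible follow-up, not part of this commit, would be to build it as torch.bool directly so later masked_fill and indexing calls avoid the uint8 warning, roughly:

import torch

logits = torch.randn(5, 2, 4)                  # (L, B, n_tags), toy values
seq_len, batch_size, _ = logits.size()
# Hypothetical follow-up, not part of this commit: build the default all-ones
# mask as bool instead of uint8.
mask = logits.new_ones((seq_len, batch_size), dtype=torch.bool)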