Merge branch 'dev' of github.com:fastnlp/fastNLP into dev

This commit is contained in:
yh_cc 2020-12-11 14:20:32 +08:00
commit d4fda68840

View File

@ -70,7 +70,7 @@ class LSTM(nn.Module):
x = x[sort_idx]
else:
x = x[:, sort_idx]
x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first)
x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first)
output, hx = self.lstm(x, hx) # -> [N,L,C]
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)