mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 20:57:37 +08:00
Merge branch 'dev' of github.com:fastnlp/fastNLP into dev
This commit is contained in:
commit
d4fda68840
@ -70,7 +70,7 @@ class LSTM(nn.Module):
|
||||
x = x[sort_idx]
|
||||
else:
|
||||
x = x[:, sort_idx]
|
||||
x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first)
|
||||
x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first)
|
||||
output, hx = self.lstm(x, hx) # -> [N,L,C]
|
||||
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
|
||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
|
||||
|
Loading…
Reference in New Issue
Block a user