mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00
Merge pull request #145 from fastnlp/choosewhatulike-patch-1
fix for changing torch API
This commit is contained in:
commit
863a99f741
@ -41,7 +41,7 @@ class VarRnnCellWrapper(nn.Module):
|
||||
return torch.cat([hi, h0[:h0_size]], dim=0)
|
||||
return hi[:size]
|
||||
is_lstm = isinstance(hidden, tuple)
|
||||
input, batch_sizes = input_x
|
||||
input, batch_sizes = input_x.data, input_x.batch_sizes
|
||||
output = []
|
||||
cell = self.cell
|
||||
if is_reversed:
|
||||
@ -127,10 +127,10 @@ class VarRNNBase(nn.Module):
|
||||
seq_len = input.size(1) if self.batch_first else input.size(0)
|
||||
max_batch_size = input.size(0) if self.batch_first else input.size(1)
|
||||
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
|
||||
input, batch_sizes = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
|
||||
input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
|
||||
else:
|
||||
max_batch_size = int(input.batch_sizes[0])
|
||||
input, batch_sizes = input
|
||||
input, batch_sizes = input.data, input.batch_sizes
|
||||
|
||||
if hx is None:
|
||||
hx = input.new_zeros(self.num_layers * self.num_directions,
|
||||
|
Loading…
Reference in New Issue
Block a user