Merge pull request #145 from fastnlp/choosewhatulike-patch-1

fix for changing torch API
This commit is contained in:
Xipeng Qiu 2019-05-03 15:03:12 +08:00 committed by GitHub
commit 863a99f741
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -41,7 +41,7 @@ class VarRnnCellWrapper(nn.Module):
return torch.cat([hi, h0[:h0_size]], dim=0)
return hi[:size]
is_lstm = isinstance(hidden, tuple)
input, batch_sizes = input_x
input, batch_sizes = input_x.data, input_x.batch_sizes
output = []
cell = self.cell
if is_reversed:
@ -127,10 +127,10 @@ class VarRNNBase(nn.Module):
seq_len = input.size(1) if self.batch_first else input.size(0)
max_batch_size = input.size(0) if self.batch_first else input.size(1)
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
input, batch_sizes = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
else:
max_batch_size = int(input.batch_sizes[0])
input, batch_sizes = input
input, batch_sizes = input.data, input.batch_sizes
if hx is None:
hx = input.new_zeros(self.num_layers * self.num_directions,