Fix LSTM bug

yh 2019-06-19 23:59:40 +08:00
parent 6b9bc007ee
commit 0f4cf30301


@@ -82,12 +82,12 @@ class LSTM(nn.Module):
             # Work around LSTM not working under DataParallel, see https://github.com/pytorch/pytorch/issues/1591
             if self.batch_first:
                 if output.size(1) < max_len:
-                    dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
-                    output = torch.cat([output, dummy_tensor], 0)
-            else:
-                if output.size(0) < max_len:
                     dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
                     output = torch.cat([output, dummy_tensor], 1)
+            else:
+                if output.size(0) < max_len:
+                    dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 0)
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
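
The change swaps the two dummy-tensor shapes so each matches the layout of output: with batch_first=True, pad_packed_sequence returns [batch, seq, hidden], so the zero padding must have shape [batch_size, pad_len, hidden] and be concatenated along dim 1; without batch_first the output is [seq, batch, hidden] and is padded along dim 0. Below is a minimal standalone sketch of that padding step, assuming the shapes above; the helper name pad_output_to_max_len and the toy tensor sizes are illustrative, not from the repository.

import torch


def pad_output_to_max_len(output, max_len, batch_first):
    """Pad an RNN output back to max_len time steps (illustrative helper).

    Under DataParallel each replica may only see sequences shorter than the
    global max_len, so the output of pad_packed_sequence has to be padded
    with zeros to a common length before the replicas' results are gathered.
    """
    if batch_first:
        # output: [batch, cur_len, hidden] -> pad along dim 1
        batch_size, cur_len, hidden = output.size()
        if cur_len < max_len:
            dummy = output.new_zeros(batch_size, max_len - cur_len, hidden)
            output = torch.cat([output, dummy], dim=1)
    else:
        # output: [cur_len, batch, hidden] -> pad along dim 0
        cur_len, batch_size, hidden = output.size()
        if cur_len < max_len:
            dummy = output.new_zeros(max_len - cur_len, batch_size, hidden)
            output = torch.cat([output, dummy], dim=0)
    return output


if __name__ == "__main__":
    out = torch.randn(4, 5, 8)  # [batch=4, cur_len=5, hidden=8]
    padded = pad_output_to_max_len(out, max_len=7, batch_first=True)
    print(padded.shape)  # torch.Size([4, 7, 8])

On recent PyTorch versions, passing total_length=max_len to torch.nn.utils.rnn.pad_packed_sequence achieves the same effect and is the documented workaround for this DataParallel issue.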