mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 21:28:01 +08:00
LSTM修改错误
This commit is contained in:
parent
6b9bc007ee
commit
0f4cf30301
@ -82,12 +82,12 @@ class LSTM(nn.Module):
|
||||
# 解决LSTM无法在DataParallel下使用的问题问题https://github.com/pytorch/pytorch/issues/1591
|
||||
if self.batch_first:
|
||||
if output.size(1) < max_len:
|
||||
dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
|
||||
output = torch.cat([output, dummy_tensor], 0)
|
||||
else:
|
||||
if output.size(0) < max_len:
|
||||
dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
|
||||
output = torch.cat([output, dummy_tensor], 1)
|
||||
else:
|
||||
if output.size(0) < max_len:
|
||||
dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
|
||||
output = torch.cat([output, dummy_tensor], 0)
|
||||
else:
|
||||
output, hx = self.lstm(x, hx)
|
||||
return output, hx
|
||||
|
Loading…
Reference in New Issue
Block a user