Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-12-04 21:28:01 +08:00
Fix a potential DataParallel issue in LSTM, and remove the initial_method parameter
This commit is contained in:
parent 8a766f070b
commit 6b9bc007ee
@@ -19,7 +19,7 @@ class LSTM(nn.Module):
     Alias: :class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM`
     
     LSTM module, a light wrapper around PyTorch's LSTM. When seq_len is provided, pack_padded_sequence is used automatically; by default the forget gate bias is initialized
-    to 1, and the module works around the issue of using LSTM under DataParallel
+    to 1, and the module works around the issue of using LSTM under DataParallel.
     
     :param input_size: feature dimension of the input `x`
     :param hidden_size: feature dimension of the hidden state `h`.
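For orientation, here is a minimal usage sketch of the wrapper this hunk documents. The constructor arguments come from the diff below; the forward call `lstm(x, seq_len)` and the `(output, (h_n, c_n))` return are assumptions read off the docstring and the `return output, hx` line, not confirmed by this hunk alone.

```python
import torch
from fastNLP.modules import LSTM  # alias named in the docstring above

# hidden_size=100 and batch_first=True are the defaults shown in this diff
lstm = LSTM(input_size=50, hidden_size=100, num_layers=1, bidirectional=False)

x = torch.randn(4, 7, 50)             # (batch, max_len, input_size), batch_first layout
seq_len = torch.tensor([7, 5, 3, 2])  # true lengths; triggers pack_padded_sequence internally

output, (h_n, c_n) = lstm(x, seq_len)  # assumed forward signature and return shape
print(output.shape)                    # torch.Size([4, 7, 100])
```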
@@ -32,13 +32,12 @@ class LSTM(nn.Module):
     """
     
     def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
-                 bidirectional=False, bias=True, initial_method=None):
+                 bidirectional=False, bias=True):
         super(LSTM, self).__init__()
         self.batch_first = batch_first
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
                             dropout=dropout, bidirectional=bidirectional)
         self.init_param()
-        initial_parameter(self, initial_method)
     
     def init_param(self):
         for name, param in self.named_parameters():
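The hunk cuts off just inside `init_param()`, but the docstring above says the forget gate bias defaults to 1. Below is a hedged sketch of what such an initializer typically looks like; the exact body is not visible in this diff, and the Xavier init for the weights is an assumption.

```python
import torch.nn as nn

def init_param(self):
    # PyTorch packs each LSTM bias as (b_ii | b_if | b_ig | b_io), one hidden_size
    # chunk per gate, so in a bias of length n = 4 * hidden_size the forget-gate
    # slice is [n // 4 : n // 2].
    for name, param in self.named_parameters():
        if 'bias' in name:
            n = param.size(0)
            param.data.fill_(0)
            param.data[n // 4:n // 2].fill_(1)  # forget gate bias -> 1
        else:
            nn.init.xavier_uniform_(param)      # weight init: assumed, not shown in the hunk
```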
@@ -81,9 +80,14 @@ class LSTM(nn.Module):
             else:
                 output = output[:, unsort_idx]
             # work around LSTM being unusable under DataParallel: https://github.com/pytorch/pytorch/issues/1591
-            if output.size(1) < max_len:
-                dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
-                output = torch.cat([output, dummy_tensor], 1)
+            if self.batch_first:
+                if output.size(1) < max_len:
+                    dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 1)
+            else:
+                if output.size(0) < max_len:
+                    dummy_tensor = output.new_zeros(max_len - output.size(0), batch_size, output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 0)
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
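Why the padding is needed at all: under `nn.DataParallel` each replica calls `pad_packed_sequence` on its own shard of the batch, so its time dimension only reaches the longest sequence in that shard. Gathering replicas whose time dimensions differ then fails (pytorch/pytorch#1591), which is what padding back to `max_len` prevents. A self-contained sketch of the `batch_first` branch, with made-up shapes:

```python
import torch

max_len = 7                       # longest sequence in the *full* batch
output = torch.randn(4, 5, 100)   # this replica's shard only reached length 5

# batch_first branch from the hunk above: pad the time dimension back to max_len
if output.size(1) < max_len:
    dummy_tensor = output.new_zeros(output.size(0), max_len - output.size(1), output.size(-1))
    output = torch.cat([output, dummy_tensor], dim=1)

print(output.shape)               # torch.Size([4, 7, 100]) on every replica, so gather succeeds
```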