LSTM中修复潜在的DataParallel可能存在的问题, 并且删除init_method参数

This commit is contained in:
yh 2019-06-19 23:49:01 +08:00
parent 8a766f070b
commit 6b9bc007ee

View File

@ -19,7 +19,7 @@ class LSTM(nn.Module):
别名:class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM`
LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化
为1; 且可以应对DataParallel中LSTM的使用问题
为1; 且可以应对DataParallel中LSTM的使用问题
:param input_size: 输入 `x` 的特征维度
:param hidden_size: 隐状态 `h` 的特征维度.
@ -32,13 +32,12 @@ class LSTM(nn.Module):
"""
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
bidirectional=False, bias=True, initial_method=None):
bidirectional=False, bias=True):
super(LSTM, self).__init__()
self.batch_first = batch_first
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
dropout=dropout, bidirectional=bidirectional)
self.init_param()
initial_parameter(self, initial_method)
def init_param(self):
for name, param in self.named_parameters():
@ -81,9 +80,14 @@ class LSTM(nn.Module):
else:
output = output[:, unsort_idx]
# 解决LSTM无法在DataParallel下使用的问题问题https://github.com/pytorch/pytorch/issues/1591
if output.size(1) < max_len:
dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
output = torch.cat([output, dummy_tensor], 1)
if self.batch_first:
if output.size(1) < max_len:
dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
output = torch.cat([output, dummy_tensor], 0)
else:
if output.size(0) < max_len:
dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
output = torch.cat([output, dummy_tensor], 1)
else:
output, hx = self.lstm(x, hx)
return output, hx