diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
index 2966426a..1cc0dec1 100644
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
@@ -19,7 +19,7 @@ class LSTM(nn.Module):
     Alias: :class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM`
 
     LSTM module, a lightweight wrapper around PyTorch's LSTM. When seq_len is provided, pack_padded_sequence is used automatically; by default the forget gate bias is initialized
-    to 1; it also works around the issue of using LSTM with DataParallel
+    to 1; it also works around the issue of using LSTM with DataParallel.
 
     :param input_size: feature dimension of the input `x`
     :param hidden_size: feature dimension of the hidden state `h`.
@@ -32,13 +32,12 @@ class LSTM(nn.Module):
     """
 
     def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
-                 bidirectional=False, bias=True, initial_method=None):
+                 bidirectional=False, bias=True):
         super(LSTM, self).__init__()
         self.batch_first = batch_first
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
                             dropout=dropout, bidirectional=bidirectional)
         self.init_param()
-        initial_parameter(self, initial_method)
 
     def init_param(self):
         for name, param in self.named_parameters():
@@ -81,9 +80,14 @@ class LSTM(nn.Module):
             else:
                 output = output[:, unsort_idx]
             # Work around LSTM not being usable under DataParallel, see https://github.com/pytorch/pytorch/issues/1591
-            if output.size(1) < max_len:
-                dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
-                output = torch.cat([output, dummy_tensor], 1)
+            if self.batch_first:
+                if output.size(1) < max_len:
+                    dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 1)
+            else:
+                if output.size(0) < max_len:
+                    dummy_tensor = output.new_zeros(max_len - output.size(0), batch_size, output.size(-1))
+                    output = torch.cat([output, dummy_tensor], 0)
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
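
Note on the last hunk: a minimal standalone sketch of the padding trick it introduces, outside of fastNLP. The helper name pad_time_dim_to_max_len and the tensor shapes are illustrative assumptions, not part of the library; the sketch only mirrors the branch-on-batch_first logic that the new lines add (pad the time axis, dim 1 when batch_first else dim 0, with zeros up to the global max_len so per-replica outputs under DataParallel can be gathered).

import torch

def pad_time_dim_to_max_len(output: torch.Tensor, max_len: int, batch_first: bool) -> torch.Tensor:
    # Hypothetical helper (not part of fastNLP): pad_packed_sequence only returns
    # as many time steps as the longest sequence in the current chunk, so under
    # DataParallel each replica may produce a different time dimension. Zero-padding
    # the time axis up to the global max_len restores a consistent shape.
    time_dim = 1 if batch_first else 0
    cur_len = output.size(time_dim)
    if cur_len < max_len:
        pad_shape = list(output.shape)
        pad_shape[time_dim] = max_len - cur_len
        dummy_tensor = output.new_zeros(pad_shape)
        output = torch.cat([output, dummy_tensor], dim=time_dim)
    return output

# batch_first output of shape (batch=2, seq=5, hidden=8), padded back to seq=7
assert pad_time_dim_to_max_len(torch.randn(2, 5, 8), 7, batch_first=True).shape == (2, 7, 8)
# seq-first output of shape (seq=5, batch=2, hidden=8), padded back to seq=7
assert pad_time_dim_to_max_len(torch.randn(5, 2, 8), 7, batch_first=False).shape == (7, 2, 8)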