Rework the fix for the ELMO and LSTM DataParallel issues

yh 2019-06-19 22:19:41 +08:00
parent a137038eb2
commit c4e131a0c5
2 changed files with 4 additions and 3 deletions

View File

@@ -762,8 +762,9 @@ class _ElmoModel(nn.Module):
if self.config['encoder']['name'] == 'elmo':
encoder_output = self.encoder(token_embedding, seq_len)
if encoder_output.size(2) < max_len:
- dummy_tensor = autograd.Variable(torch.zeros(batch_size, max_len - encoder_output.size(2), encoder_output.size(-1)))
- encoder_output = torch.cat([encoder_output, dummy_tensor], 1)
+ dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size,
+                                          max_len - encoder_output.size(2), encoder_output.size(-1))
+ encoder_output = torch.cat([encoder_output, dummy_tensor], 2)
sz = encoder_output.size() # num_layers, batch_size, max_len, hidden_size
token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3])
encoder_output = torch.cat([token_embedding, encoder_output], dim=0)
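
The ELMo encoder output at this point is a 4-D tensor of shape (num_layers, batch_size, seq_len, hidden_size), so the replacement allocates the padding with the full 4-D shape and concatenates along dim 2, the time axis, instead of dim 1. A minimal sketch of that padding pattern, with a hypothetical helper name and the shape taken from the hunk above:

import torch

def pad_elmo_output(encoder_output: torch.Tensor, max_len: int) -> torch.Tensor:
    # encoder_output: (num_layers, batch_size, seq_len, hidden_size)
    if encoder_output.size(2) < max_len:
        # new_zeros keeps the padding on the same device and with the same dtype as
        # encoder_output, which is what each DataParallel replica needs for torch.cat.
        dummy = encoder_output.new_zeros(encoder_output.size(0), encoder_output.size(1),
                                         max_len - encoder_output.size(2), encoder_output.size(-1))
        encoder_output = torch.cat([encoder_output, dummy], dim=2)
    return encoder_output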

View File

@@ -82,7 +82,7 @@ class LSTM(nn.Module):
output = output[:, unsort_idx]
# Work around the issue that LSTM cannot be used under DataParallel, see https://github.com/pytorch/pytorch/issues/1591
if output.size(1) < max_len:
- dummy_tensor = autograd.Variable(torch.zeros(batch_size, max_len - output.size(1), output.size(-1)))
+ dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
output = torch.cat([output, dummy_tensor], 1)
else:
output, hx = self.lstm(x, hx)
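
Both hunks work around the same failure mode: under nn.DataParallel each replica runs pack_padded_sequence/pad_packed_sequence on its own slice of the batch, so the unpacked length equals the longest sequence in that slice rather than the global max_len, and gathering the replicas' outputs fails on mismatched shapes (pytorch/pytorch#1591). Building the padding with new_zeros instead of autograd.Variable(torch.zeros(...)) also allocates it on the replica's device with the right dtype, whereas the old code produced a CPU tensor that torch.cat would reject on GPU. A minimal sketch of the LSTM-side pattern, with a hypothetical helper name, assuming a batch-first (batch_size, seq_len, hidden_size) output:

import torch

def pad_lstm_output(output: torch.Tensor, max_len: int) -> torch.Tensor:
    # output: (batch_size, seq_len, hidden_size) from pad_packed_sequence;
    # inside a DataParallel replica, seq_len can be shorter than the global max_len.
    if output.size(1) < max_len:
        dummy = output.new_zeros(output.size(0), max_len - output.size(1), output.size(-1))
        output = torch.cat([output, dummy], dim=1)
    return output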