Merge pull request #104 from xuyige/master

fix a bug in init and add dropout in MLP
Yige XU 2018-11-06 20:18:40 +08:00 committed by GitHub
commit 5ec58e3b86
2 changed files with 9 additions and 6 deletions


@@ -4,12 +4,13 @@ from fastNLP.modules.utils import initial_parameter
 class MLP(nn.Module):
-    def __init__(self, size_layer, activation='relu', initial_method=None):
+    def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0):
         """Multilayer Perceptrons as a decoder
         :param size_layer: list of int, define the size of MLP layers.
         :param activation: str or function, the activation function for hidden layers.
         :param initial_method: str, the name of init method.
+        :param dropout: float, the probability of dropout.

         .. note::
             There is no activation function applying on output layer.
@@ -24,6 +25,8 @@ class MLP(nn.Module):
             else:
                 self.hiddens.append(nn.Linear(size_layer[i-1], size_layer[i]))
+        self.dropout = nn.Dropout(p=dropout)
+
         actives = {
             'relu': nn.ReLU(),
             'tanh': nn.Tanh(),
@@ -38,8 +41,8 @@ class MLP(nn.Module):
     def forward(self, x):
         for layer in self.hiddens:
-            x = self.hidden_active(layer(x))
-        x = self.output(x)
+            x = self.dropout(self.hidden_active(layer(x)))
+        x = self.dropout(self.output(x))
         return x
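For context, a minimal usage sketch of the patched MLP (the import path is an assumption and may differ in this version of fastNLP). With the forward change above, dropout is now applied after every hidden activation and after the output layer:

import torch
from fastNLP.modules import MLP  # assumed import path

# size_layer [100, 64, 10]: one hidden layer 100->64 with ReLU, output layer 64->10.
# dropout=0.3 is applied after the hidden activation and after the output.
mlp = MLP([100, 64, 10], activation='relu', dropout=0.3)
x = torch.randn(32, 100)  # a batch of 32 examples
y = mlp(x)                # -> shape (32, 10)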


@@ -32,9 +32,9 @@ def initial_parameter(net, initial_method=None):
     elif initial_method == 'xavier_normal':
         init_method = init.xavier_normal_
     elif initial_method == 'kaiming_normal' or initial_method == 'msra':
-        init_method = init.kaiming_normal
+        init_method = init.kaiming_normal_
     elif initial_method == 'kaiming_uniform':
-        init_method = init.kaiming_normal
+        init_method = init.kaiming_uniform_
     elif initial_method == 'orthogonal':
         init_method = init.orthogonal_
     elif initial_method == 'sparse':
@@ -42,7 +42,7 @@ def initial_parameter(net, initial_method=None):
     elif initial_method == 'normal':
         init_method = init.normal_
     elif initial_method == 'uniform':
-        initial_method = init.uniform_
+        init_method = init.uniform_
     else:
         init_method = init.xavier_normal_
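
These hunks fix the same family of bug: in torch.nn.init, the trailing underscore marks the in-place initializers, and the un-suffixed init.kaiming_normal was a deprecated alias; the 'kaiming_uniform' branch also dispatched to the normal variant, and the 'uniform' branch assigned to initial_method instead of init_method, leaving init_method undefined. A minimal sketch of the now-correct behavior (the helper import path is taken from the hunk header above):

import torch
import torch.nn as nn
from torch.nn import init
from fastNLP.modules.utils import initial_parameter

# The in-place initializers the fixed branches now dispatch to:
w = torch.empty(8, 16)
init.kaiming_normal_(w)   # fills w in place (He/MSRA normal)
init.kaiming_uniform_(w)  # previously this branch ran kaiming_normal by mistake

# Through the fixed helper:
net = nn.Linear(16, 8)
initial_parameter(net, initial_method='kaiming_uniform')
initial_parameter(net, initial_method='uniform')  # init_method is now assigned correctly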