Merge pull request #104 from xuyige/master
fix a bug in init and add dropout in MLP
Commit 5ec58e3b86
@@ -4,12 +4,13 @@ from fastNLP.modules.utils import initial_parameter
 
 
 class MLP(nn.Module):
-    def __init__(self, size_layer, activation='relu', initial_method=None):
+    def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0):
         """Multilayer Perceptrons as a decoder
 
         :param size_layer: list of int, define the size of MLP layers.
         :param activation: str or function, the activation function for hidden layers.
         :param initial_method: str, the name of init method.
+        :param dropout: float, the probability of dropout.
 
         .. note::
             There is no activation function applying on output layer.
@@ -24,6 +25,8 @@ class MLP(nn.Module):
             else:
                 self.hiddens.append(nn.Linear(size_layer[i-1], size_layer[i]))
 
+        self.dropout = nn.Dropout(p=dropout)
+
         actives = {
             'relu': nn.ReLU(),
             'tanh': nn.Tanh(),
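For orientation, here is a minimal sketch of what the constructor assembles after this hunk: hidden nn.Linear layers built from size_layer, plus the new nn.Dropout module. The hiddens/output split is inferred from the forward pass below, and the concrete sizes and dropout probability are made-up values, not the repository code verbatim.

import torch.nn as nn

size_layer = [128, 64, 32, 10]   # hypothetical layer sizes
dropout = 0.3                    # hypothetical dropout probability

# one nn.Linear per consecutive pair of sizes, except the last pair
hiddens = nn.ModuleList(
    [nn.Linear(size_layer[i - 1], size_layer[i]) for i in range(1, len(size_layer) - 1)]
)
output = nn.Linear(size_layer[-2], size_layer[-1])   # final projection, no activation
drop = nn.Dropout(p=dropout)                         # the module introduced by this commit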
@@ -38,8 +41,8 @@ class MLP(nn.Module):
 
     def forward(self, x):
         for layer in self.hiddens:
-            x = self.hidden_active(layer(x))
-        x = self.output(x)
+            x = self.dropout(self.hidden_active(layer(x)))
+        x = self.dropout(self.output(x))
         return x
 
 
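A short usage sketch of the updated forward path, written against plain torch so it runs standalone; it mirrors the two changed lines above (dropout after each hidden activation and after the output layer). Layer sizes and the dropout probability are invented for illustration.

import torch
import torch.nn as nn

hidden = nn.Linear(10, 20)
out = nn.Linear(20, 5)
act = nn.ReLU()
drop = nn.Dropout(p=0.5)

x = torch.randn(4, 10)
x = drop(act(hidden(x)))   # mirrors: x = self.dropout(self.hidden_active(layer(x)))
x = drop(out(x))           # mirrors: x = self.dropout(self.output(x))

Since nn.Dropout is a no-op in eval() mode, the extra dropout calls only affect training.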
@@ -32,9 +32,9 @@ def initial_parameter(net, initial_method=None):
     elif initial_method == 'xavier_normal':
         init_method = init.xavier_normal_
     elif initial_method == 'kaiming_normal' or initial_method == 'msra':
-        init_method = init.kaiming_normal
+        init_method = init.kaiming_normal_
     elif initial_method == 'kaiming_uniform':
-        init_method = init.kaiming_normal
+        init_method = init.kaiming_uniform_
     elif initial_method == 'orthogonal':
         init_method = init.orthogonal_
     elif initial_method == 'sparse':
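This hunk switches to the in-place, underscore-suffixed initializers in torch.nn.init (the non-underscore names are deprecated aliases) and also corrects the 'kaiming_uniform' branch, which previously selected the kaiming_normal initializer. A hedged sketch of the corrected mapping; the dict is illustrative, not the code in utils.py:

import torch
from torch.nn import init

method_map = {
    'kaiming_normal': init.kaiming_normal_,
    'msra': init.kaiming_normal_,
    'kaiming_uniform': init.kaiming_uniform_,   # previously mis-mapped to kaiming_normal
}

w = torch.empty(3, 5)
method_map['kaiming_uniform'](w)   # initializes w in place and returns it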
@@ -42,7 +42,7 @@ def initial_parameter(net, initial_method=None):
     elif initial_method == 'normal':
         init_method = init.normal_
     elif initial_method == 'uniform':
-        initial_method = init.uniform_
+        init_method = init.uniform_
     else:
         init_method = init.xavier_normal_
 
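The second fix is a typo: the 'uniform' branch rebound the argument initial_method instead of the local init_method, so that branch never actually selected init.uniform_. A usage sketch after the fix; it assumes fastNLP is installed and that initial_parameter applies the chosen initializer over the network's parameters, with the import path taken from the first hunk above.

import torch.nn as nn
from fastNLP.modules.utils import initial_parameter

net = nn.Linear(16, 4)                            # hypothetical module to initialize
initial_parameter(net, initial_method='uniform')  # now resolves to init.uniform_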