mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00
Rewrite classification model, add intialization for conv_maxpool
This commit is contained in:
parent
9dc32f68a7
commit
d910ae3c77
@ -269,7 +269,7 @@ class ClassPreprocess(BasePreprocess):
|
|||||||
|
|
||||||
for word in sent:
|
for word in sent:
|
||||||
if word not in word2index:
|
if word not in word2index:
|
||||||
word2index[word[0]] = len(word2index)
|
word2index[word] = len(word2index)
|
||||||
return word2index, label2index
|
return word2index, label2index
|
||||||
|
|
||||||
def to_index(self, data):
|
def to_index(self, data):
|
||||||
|
@ -5,7 +5,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# import torch.nn.functional as F
|
# import torch.nn.functional as F
|
||||||
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool
|
import fastNLP.modules.encoder as encoder
|
||||||
|
|
||||||
|
|
||||||
class CNNText(torch.nn.Module):
|
class CNNText(torch.nn.Module):
|
||||||
@ -18,22 +18,22 @@ class CNNText(torch.nn.Module):
|
|||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super(CNNText, self).__init__()
|
super(CNNText, self).__init__()
|
||||||
|
|
||||||
class_num = args["num_classes"]
|
num_classes = args["num_classes"]
|
||||||
kernel_nums = [100, 100, 100]
|
kernel_nums = [100, 100, 100]
|
||||||
kernel_sizes = [3, 4, 5]
|
kernel_sizes = [3, 4, 5]
|
||||||
embed_num = args["vocab_size"]
|
vocab_size = args["vocab_size"]
|
||||||
embed_dim = 300
|
embed_dim = 300
|
||||||
pretrained_embed = None
|
pretrained_embed = None
|
||||||
drop_prob = 0.5
|
drop_prob = 0.5
|
||||||
|
|
||||||
# no support for pre-trained embedding currently
|
# no support for pre-trained embedding currently
|
||||||
self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0)
|
self.embed = encoder.embedding.Embedding(vocab_size, embed_dim)
|
||||||
self.conv_pool = ConvMaxpool(
|
self.conv_pool = encoder.conv_maxpool.ConvMaxpool(
|
||||||
in_channels=embed_dim,
|
in_channels=embed_dim,
|
||||||
out_channels=kernel_nums,
|
out_channels=kernel_nums,
|
||||||
kernel_sizes=kernel_sizes)
|
kernel_sizes=kernel_sizes)
|
||||||
self.dropout = nn.Dropout(drop_prob)
|
self.dropout = nn.Dropout(drop_prob)
|
||||||
self.fc = nn.Linear(sum(kernel_nums), class_num)
|
self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.embed(x) # [N,L] -> [N,L,C]
|
x = self.embed(x) # [N,L] -> [N,L,C]
|
||||||
|
@ -2,8 +2,10 @@ from .embedding import Embedding
|
|||||||
from .linear import Linear
|
from .linear import Linear
|
||||||
from .lstm import Lstm
|
from .lstm import Lstm
|
||||||
from .conv import Conv
|
from .conv import Conv
|
||||||
|
from .conv_maxpool import ConvMaxpool
|
||||||
|
|
||||||
__all__ = ["Lstm",
|
__all__ = ["Lstm",
|
||||||
"Embedding",
|
"Embedding",
|
||||||
"Linear",
|
"Linear",
|
||||||
"Conv"]
|
"Conv",
|
||||||
|
"ConvMaxpool"]
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.init import xavier_uniform_
|
||||||
|
|
||||||
|
|
||||||
class ConvMaxpool(nn.Module):
|
class ConvMaxpool(nn.Module):
|
||||||
@ -21,6 +22,7 @@ class ConvMaxpool(nn.Module):
|
|||||||
if isinstance(kernel_sizes, int):
|
if isinstance(kernel_sizes, int):
|
||||||
out_channels = [out_channels]
|
out_channels = [out_channels]
|
||||||
kernel_sizes = [kernel_sizes]
|
kernel_sizes = [kernel_sizes]
|
||||||
|
|
||||||
self.convs = nn.ModuleList([nn.Conv1d(
|
self.convs = nn.ModuleList([nn.Conv1d(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=oc,
|
out_channels=oc,
|
||||||
@ -31,6 +33,9 @@ class ConvMaxpool(nn.Module):
|
|||||||
groups=groups,
|
groups=groups,
|
||||||
bias=bias)
|
bias=bias)
|
||||||
for oc, ks in zip(out_channels, kernel_sizes)])
|
for oc, ks in zip(out_channels, kernel_sizes)])
|
||||||
|
|
||||||
|
for conv in self.convs:
|
||||||
|
xavier_uniform_(conv.weight) # weight initialization
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'Incorrect kernel sizes: should be list, tuple or int')
|
'Incorrect kernel sizes: should be list, tuple or int')
|
||||||
|
Loading…
Reference in New Issue
Block a user