Rewrite classification model, add initialization for conv_maxpool

Ke Zhen 2018-09-02 16:39:36 +08:00
parent 9dc32f68a7
commit d910ae3c77
4 changed files with 15 additions and 8 deletions


@@ -269,7 +269,7 @@ class ClassPreprocess(BasePreprocess):
             for word in sent:
                 if word not in word2index:
-                    word2index[word[0]] = len(word2index)
+                    word2index[word] = len(word2index)
         return word2index, label2index

     def to_index(self, data):
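The fix above addresses a vocabulary bug: `word2index[word[0]]` keyed the dictionary on the first character of each word instead of the word itself. A minimal standalone sketch of the corrected build (the function name and special tokens are assumptions for illustration, not part of the commit):

    def build_word_dict(sents):
        word2index = {"<pad>": 0, "<unk>": 1}  # assumed special tokens
        for sent in sents:
            for word in sent:
                if word not in word2index:
                    word2index[word] = len(word2index)  # key on the whole word
        return word2index

    print(build_word_dict([["hello", "hat"], ["hello"]]))
    # {'<pad>': 0, '<unk>': 1, 'hello': 2, 'hat': 3}
    # with the old word2index[word[0]], "hello" and "hat" collide on key "h"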


@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 # import torch.nn.functional as F
-from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool
+import fastNLP.modules.encoder as encoder


 class CNNText(torch.nn.Module):
@@ -18,22 +18,22 @@ class CNNText(torch.nn.Module):
     def __init__(self, args):
         super(CNNText, self).__init__()

-        class_num = args["num_classes"]
+        num_classes = args["num_classes"]
         kernel_nums = [100, 100, 100]
         kernel_sizes = [3, 4, 5]
-        embed_num = args["vocab_size"]
+        vocab_size = args["vocab_size"]
         embed_dim = 300
         pretrained_embed = None
         drop_prob = 0.5

         # no support for pre-trained embedding currently
-        self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0)
-        self.conv_pool = ConvMaxpool(
+        self.embed = encoder.embedding.Embedding(vocab_size, embed_dim)
+        self.conv_pool = encoder.conv_maxpool.ConvMaxpool(
             in_channels=embed_dim,
             out_channels=kernel_nums,
             kernel_sizes=kernel_sizes)
         self.dropout = nn.Dropout(drop_prob)
-        self.fc = nn.Linear(sum(kernel_nums), class_num)
+        self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes)

     def forward(self, x):
         x = self.embed(x)  # [N,L] -> [N,L,C]
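For readers without fastNLP installed, a rough self-contained equivalent of the rewritten model in plain PyTorch; it assumes the `encoder.*` wrappers behave like their `torch.nn` counterparts (the conv_maxpool hunk below shows it is built on `nn.Conv1d`), and `CNNTextSketch` is a hypothetical name:

    import torch
    import torch.nn as nn

    class CNNTextSketch(nn.Module):
        def __init__(self, args):
            super(CNNTextSketch, self).__init__()
            kernel_nums = [100, 100, 100]
            kernel_sizes = [3, 4, 5]
            embed_dim = 300
            self.embed = nn.Embedding(args["vocab_size"], embed_dim)
            self.convs = nn.ModuleList(
                [nn.Conv1d(embed_dim, oc, ks)
                 for oc, ks in zip(kernel_nums, kernel_sizes)])
            self.dropout = nn.Dropout(0.5)
            self.fc = nn.Linear(sum(kernel_nums), args["num_classes"])

        def forward(self, x):                  # x: [N, L] token ids
            x = self.embed(x).transpose(1, 2)  # [N, L, C] -> [N, C, L] for Conv1d
            pooled = [torch.relu(conv(x)).max(dim=2)[0] for conv in self.convs]
            return self.fc(self.dropout(torch.cat(pooled, dim=1)))  # [N, num_classes]

    model = CNNTextSketch({"vocab_size": 1000, "num_classes": 5})
    logits = model(torch.randint(0, 1000, (2, 20)))  # -> shape [2, 5]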


@@ -2,8 +2,10 @@ from .embedding import Embedding
 from .linear import Linear
 from .lstm import Lstm
 from .conv import Conv
+from .conv_maxpool import ConvMaxpool

 __all__ = ["Lstm",
            "Embedding",
            "Linear",
-           "Conv"]
+           "Conv",
+           "ConvMaxpool"]


@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.init import xavier_uniform_


 class ConvMaxpool(nn.Module):
@@ -21,6 +22,7 @@ class ConvMaxpool(nn.Module):
             if isinstance(kernel_sizes, int):
                 out_channels = [out_channels]
                 kernel_sizes = [kernel_sizes]
+
             self.convs = nn.ModuleList([nn.Conv1d(
                 in_channels=in_channels,
                 out_channels=oc,
@@ -31,6 +33,9 @@ class ConvMaxpool(nn.Module):
                 groups=groups,
                 bias=bias)
                 for oc, ks in zip(out_channels, kernel_sizes)])
+
+            for conv in self.convs:
+                xavier_uniform_(conv.weight)  # weight initialization
         else:
             raise Exception(
                 'Incorrect kernel sizes: should be list, tuple or int')
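A standalone look at the added initialization, runnable without fastNLP: `xavier_uniform_` fills a weight tensor with samples from U(-a, a), where a = gain * sqrt(6 / (fan_in + fan_out)), which keeps activation variance roughly constant across layers. Note the loop touches only `conv.weight`, so biases keep PyTorch's default initialization. The channel and kernel numbers mirror the CNNText defaults:

    import torch.nn as nn
    from torch.nn.init import xavier_uniform_

    convs = nn.ModuleList([nn.Conv1d(in_channels=300, out_channels=oc, kernel_size=ks)
                           for oc, ks in zip([100, 100, 100], [3, 4, 5])])
    for conv in convs:
        xavier_uniform_(conv.weight)  # same call the commit adds; bias left at default

    # bound check: fan_in = 300*3, fan_out = 100*3, so |w| <= sqrt(6 / (900 + 300))
    print(convs[0].weight.abs().max() <= (6 / (900 + 300)) ** 0.5)  # tensor(True)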