From 1d5bb0a3b6e36a1634e088593724770f383ad33f Mon Sep 17 00:00:00 2001 From: yh Date: Mon, 19 Nov 2018 19:16:09 +0800 Subject: [PATCH] =?UTF-8?q?bug=20fix=E2=80=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset.py | 3 ++- reproduction/CNN-sentence_classification/model.py | 15 ++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 8375cf74..c8bd67e7 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -188,7 +188,8 @@ class DataSet(object): results.append(func(ins)) if new_field_name is not None: self.add_field(new_field_name, results) - return results + else: + return results if __name__ == '__main__': from fastNLP.core.instance import Instance diff --git a/reproduction/CNN-sentence_classification/model.py b/reproduction/CNN-sentence_classification/model.py index 125e7bcc..870e7c4e 100644 --- a/reproduction/CNN-sentence_classification/model.py +++ b/reproduction/CNN-sentence_classification/model.py @@ -4,8 +4,8 @@ import torch.nn.functional as F class CNN_text(nn.Module): - def __init__(self, kernel_h=[3, 4, 5], kernel_num=100, embed_num=1000, embed_dim=300, dropout=0.5, L2_constrain=3, - batchsize=50, pretrained_embeddings=None): + def __init__(self, kernel_h=[3, 4, 5], kernel_num=100, embed_num=1000, embed_dim=300, num_classes=2, dropout=0.5, L2_constrain=3, + pretrained_embeddings=None): super(CNN_text, self).__init__() self.embedding = nn.Embedding(embed_num, embed_dim) @@ -15,11 +15,11 @@ class CNN_text(nn.Module): # the network structure # Conv2d: input- N,C,H,W output- (50,100,62,1) - self.conv1 = nn.ModuleList([nn.Conv2d(1, 100, (K, 300)) for K in kernel_h]) - self.fc1 = nn.Linear(300, 2) + self.conv1 = nn.ModuleList([nn.Conv2d(1, kernel_num, (K, embed_dim)) for K in kernel_h]) + self.fc1 = nn.Linear(len(kernel_h)*kernel_num, num_classes) def max_pooling(self, x): - x = F.relu(conv(x)).squeeze(3) # N,C,L - (50,100,62) + x = F.relu(self.conv1(x)).squeeze(3) # N,C,L - (50,100,62) x = F.max_pool1d(x, x.size(2)).squeeze(2) # x.size(2)=62 squeeze: (50,100,1) -> (50,100) return x @@ -33,3 +33,8 @@ class CNN_text(nn.Module): x = self.dropout(x) x = self.fc1(x) return x + +if __name__ == '__main__': + model = CNN_text(kernel_h=[1, 2, 3, 4],embed_num=3, embed_dim=2) + x = torch.LongTensor([[1, 2, 1, 2, 0]]) + print(model(x)) \ No newline at end of file