mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
bug fix“
This commit is contained in:
parent
4149eb9c06
commit
1d5bb0a3b6
@ -188,7 +188,8 @@ class DataSet(object):
|
||||
results.append(func(ins))
|
||||
if new_field_name is not None:
|
||||
self.add_field(new_field_name, results)
|
||||
return results
|
||||
else:
|
||||
return results
|
||||
|
||||
if __name__ == '__main__':
|
||||
from fastNLP.core.instance import Instance
|
||||
|
@ -4,8 +4,8 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class CNN_text(nn.Module):
|
||||
def __init__(self, kernel_h=[3, 4, 5], kernel_num=100, embed_num=1000, embed_dim=300, dropout=0.5, L2_constrain=3,
|
||||
batchsize=50, pretrained_embeddings=None):
|
||||
def __init__(self, kernel_h=[3, 4, 5], kernel_num=100, embed_num=1000, embed_dim=300, num_classes=2, dropout=0.5, L2_constrain=3,
|
||||
pretrained_embeddings=None):
|
||||
super(CNN_text, self).__init__()
|
||||
|
||||
self.embedding = nn.Embedding(embed_num, embed_dim)
|
||||
@ -15,11 +15,11 @@ class CNN_text(nn.Module):
|
||||
|
||||
# the network structure
|
||||
# Conv2d: input- N,C,H,W output- (50,100,62,1)
|
||||
self.conv1 = nn.ModuleList([nn.Conv2d(1, 100, (K, 300)) for K in kernel_h])
|
||||
self.fc1 = nn.Linear(300, 2)
|
||||
self.conv1 = nn.ModuleList([nn.Conv2d(1, kernel_num, (K, embed_dim)) for K in kernel_h])
|
||||
self.fc1 = nn.Linear(len(kernel_h)*kernel_num, num_classes)
|
||||
|
||||
def max_pooling(self, x):
|
||||
x = F.relu(conv(x)).squeeze(3) # N,C,L - (50,100,62)
|
||||
x = F.relu(self.conv1(x)).squeeze(3) # N,C,L - (50,100,62)
|
||||
x = F.max_pool1d(x, x.size(2)).squeeze(2)
|
||||
# x.size(2)=62 squeeze: (50,100,1) -> (50,100)
|
||||
return x
|
||||
@ -33,3 +33,8 @@ class CNN_text(nn.Module):
|
||||
x = self.dropout(x)
|
||||
x = self.fc1(x)
|
||||
return x
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = CNN_text(kernel_h=[1, 2, 3, 4],embed_num=3, embed_dim=2)
|
||||
x = torch.LongTensor([[1, 2, 1, 2, 0]])
|
||||
print(model(x))
|
Loading…
Reference in New Issue
Block a user