Merge pull request #46 from fastnlp/modify-readme-example

modify readme example
This commit is contained in:
Coet 2018-08-24 19:15:16 +08:00 committed by GitHub
commit 96391d6ab3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 21 deletions

View File

@@ -30,6 +30,7 @@ A typical fastNLP routine is composed of four phases: loading dataset, pre-proce
from fastNLP.models.base_model import BaseModel from fastNLP.models.base_model import BaseModel
from fastNLP.modules import encoder from fastNLP.modules import encoder
from fastNLP.modules import aggregation from fastNLP.modules import aggregation
from fastNLP.modules import decoder
from fastNLP.loader.dataset_loader import ClassDatasetLoader from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.preprocess import ClassPreprocess from fastNLP.loader.preprocess import ClassPreprocess
@@ -42,20 +43,20 @@ class ClassificationModel(BaseModel):
Simple text classification model based on CNN. Simple text classification model based on CNN.
""" """
def __init__(self, class_num, vocab_size): def __init__(self, num_classes, vocab_size):
super(ClassificationModel, self).__init__() super(ClassificationModel, self).__init__()
self.embed = encoder.Embedding(nums=vocab_size, dims=300) self.emb = encoder.Embedding(nums=vocab_size, dims=300)
self.conv = encoder.Conv( self.enc = encoder.Conv(
in_channels=300, out_channels=100, kernel_size=3) in_channels=300, out_channels=100, kernel_size=3)
self.pool = aggregation.MaxPool() self.agg = aggregation.MaxPool()
self.output = encoder.Linear(input_size=100, output_size=class_num) self.dec = decoder.MLP(100, num_classes=num_classes)
def forward(self, x): def forward(self, x):
x = self.embed(x) # [N,L] -> [N,L,C] x = self.emb(x) # [N,L] -> [N,L,C]
x = self.conv(x) # [N,L,C_in] -> [N,L,C_out] x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
x = self.pool(x) # [N,L,C] -> [N,C] x = self.agg(x) # [N,L,C] -> [N,C]
x = self.output(x) # [N,C] -> [N, N_class] x = self.dec(x) # [N,C] -> [N, N_class]
return x return x
@@ -75,7 +76,7 @@ model_args = {
'num_classes': n_classes, 'num_classes': n_classes,
'vocab_size': vocab_size 'vocab_size': vocab_size
} }
model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size) model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
# train model # train model
train_args = { train_args = {

View File

@@ -13,6 +13,7 @@ from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.models.base_model import BaseModel from fastNLP.models.base_model import BaseModel
from fastNLP.modules import aggregation from fastNLP.modules import aggregation
from fastNLP.modules import encoder from fastNLP.modules import encoder
from fastNLP.modules import decoder
class ClassificationModel(BaseModel): class ClassificationModel(BaseModel):
@@ -20,20 +21,20 @@ class ClassificationModel(BaseModel):
Simple text classification model based on CNN. Simple text classification model based on CNN.
""" """
def __init__(self, class_num, vocab_size): def __init__(self, num_classes, vocab_size):
super(ClassificationModel, self).__init__() super(ClassificationModel, self).__init__()
self.embed = encoder.Embedding(nums=vocab_size, dims=300) self.emb = encoder.Embedding(nums=vocab_size, dims=300)
self.conv = encoder.Conv( self.enc = encoder.Conv(
in_channels=300, out_channels=100, kernel_size=3) in_channels=300, out_channels=100, kernel_size=3)
self.pool = aggregation.MaxPool() self.agg = aggregation.MaxPool()
self.output = encoder.Linear(input_size=100, output_size=class_num) self.dec = decoder.MLP(100, num_classes=num_classes)
def forward(self, x): def forward(self, x):
x = self.embed(x) # [N,L] -> [N,L,C] x = self.emb(x) # [N,L] -> [N,L,C]
x = self.conv(x) # [N,L,C_in] -> [N,L,C_out] x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
x = self.pool(x) # [N,L,C] -> [N,C] x = self.agg(x) # [N,L,C] -> [N,C]
x = self.output(x) # [N,C] -> [N, N_class] x = self.dec(x) # [N,C] -> [N, N_class]
return x return x
@@ -55,7 +56,7 @@ model_args = {
'num_classes': n_classes, 'num_classes': n_classes,
'vocab_size': vocab_size 'vocab_size': vocab_size
} }
model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size) model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
# train model # train model
train_args = { train_args = {
@@ -75,4 +76,4 @@ trainer.cross_validate(model)
# predict using model # predict using model
data_infer = [x[0] for x in data] data_infer = [x[0] for x in data]
infer = ClassificationInfer(data_dir) infer = ClassificationInfer(data_dir)
labels_pred = infer.predict(model, data_infer) labels_pred = infer.predict(model, data_infer)