Merge pull request #46 from fastnlp/modify-readme-example

modify readme example
This commit is contained in:
Coet 2018-08-24 19:15:16 +08:00 committed by GitHub
commit 96391d6ab3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 21 deletions

View File

@ -30,6 +30,7 @@ A typical fastNLP routine is composed of four phases: loading dataset, pre-proce
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import encoder
from fastNLP.modules import aggregation
from fastNLP.modules import decoder
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.preprocess import ClassPreprocess
@ -42,20 +43,20 @@ class ClassificationModel(BaseModel):
Simple text classification model based on CNN.
"""
def __init__(self, class_num, vocab_size):
def __init__(self, num_classes, vocab_size):
super(ClassificationModel, self).__init__()
self.embed = encoder.Embedding(nums=vocab_size, dims=300)
self.conv = encoder.Conv(
self.emb = encoder.Embedding(nums=vocab_size, dims=300)
self.enc = encoder.Conv(
in_channels=300, out_channels=100, kernel_size=3)
self.pool = aggregation.MaxPool()
self.output = encoder.Linear(input_size=100, output_size=class_num)
self.agg = aggregation.MaxPool()
self.dec = decoder.MLP(100, num_classes=num_classes)
def forward(self, x):
x = self.embed(x) # [N,L] -> [N,L,C]
x = self.conv(x) # [N,L,C_in] -> [N,L,C_out]
x = self.pool(x) # [N,L,C] -> [N,C]
x = self.output(x) # [N,C] -> [N, N_class]
x = self.emb(x) # [N,L] -> [N,L,C]
x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
x = self.agg(x) # [N,L,C] -> [N,C]
x = self.dec(x) # [N,C] -> [N, N_class]
return x
@ -75,7 +76,7 @@ model_args = {
'num_classes': n_classes,
'vocab_size': vocab_size
}
model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size)
model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
# train model
train_args = {

View File

@ -13,6 +13,7 @@ from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import aggregation
from fastNLP.modules import encoder
from fastNLP.modules import decoder
class ClassificationModel(BaseModel):
@ -20,20 +21,20 @@ class ClassificationModel(BaseModel):
Simple text classification model based on CNN.
"""
def __init__(self, class_num, vocab_size):
def __init__(self, num_classes, vocab_size):
super(ClassificationModel, self).__init__()
self.embed = encoder.Embedding(nums=vocab_size, dims=300)
self.conv = encoder.Conv(
self.emb = encoder.Embedding(nums=vocab_size, dims=300)
self.enc = encoder.Conv(
in_channels=300, out_channels=100, kernel_size=3)
self.pool = aggregation.MaxPool()
self.output = encoder.Linear(input_size=100, output_size=class_num)
self.agg = aggregation.MaxPool()
self.dec = decoder.MLP(100, num_classes=num_classes)
def forward(self, x):
x = self.embed(x) # [N,L] -> [N,L,C]
x = self.conv(x) # [N,L,C_in] -> [N,L,C_out]
x = self.pool(x) # [N,L,C] -> [N,C]
x = self.output(x) # [N,C] -> [N, N_class]
x = self.emb(x) # [N,L] -> [N,L,C]
x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
x = self.agg(x) # [N,L,C] -> [N,C]
x = self.dec(x) # [N,C] -> [N, N_class]
return x
@ -55,7 +56,7 @@ model_args = {
'num_classes': n_classes,
'vocab_size': vocab_size
}
model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size)
model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
# train model
train_args = {
@ -75,4 +76,4 @@ trainer.cross_validate(model)
# predict using model
data_infer = [x[0] for x in data]
infer = ClassificationInfer(data_dir)
labels_pred = infer.predict(model, data_infer)
labels_pred = infer.predict(model, data_infer)