mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 11:17:50 +08:00
Merge pull request #46 from fastnlp/modify-readme-example
modify readme example
This commit is contained in:
commit
96391d6ab3
21
README.md
21
README.md
@ -30,6 +30,7 @@ A typical fastNLP routine is composed of four phases: loading dataset, pre-proce
|
||||
from fastNLP.models.base_model import BaseModel
|
||||
from fastNLP.modules import encoder
|
||||
from fastNLP.modules import aggregation
|
||||
from fastNLP.modules import decoder
|
||||
|
||||
from fastNLP.loader.dataset_loader import ClassDatasetLoader
|
||||
from fastNLP.loader.preprocess import ClassPreprocess
|
||||
@ -42,20 +43,20 @@ class ClassificationModel(BaseModel):
|
||||
Simple text classification model based on CNN.
|
||||
"""
|
||||
|
||||
def __init__(self, class_num, vocab_size):
|
||||
def __init__(self, num_classes, vocab_size):
|
||||
super(ClassificationModel, self).__init__()
|
||||
|
||||
self.embed = encoder.Embedding(nums=vocab_size, dims=300)
|
||||
self.conv = encoder.Conv(
|
||||
self.emb = encoder.Embedding(nums=vocab_size, dims=300)
|
||||
self.enc = encoder.Conv(
|
||||
in_channels=300, out_channels=100, kernel_size=3)
|
||||
self.pool = aggregation.MaxPool()
|
||||
self.output = encoder.Linear(input_size=100, output_size=class_num)
|
||||
self.agg = aggregation.MaxPool()
|
||||
self.dec = decoder.MLP(100, num_classes=num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.embed(x) # [N,L] -> [N,L,C]
|
||||
x = self.conv(x) # [N,L,C_in] -> [N,L,C_out]
|
||||
x = self.pool(x) # [N,L,C] -> [N,C]
|
||||
x = self.output(x) # [N,C] -> [N, N_class]
|
||||
x = self.emb(x) # [N,L] -> [N,L,C]
|
||||
x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
|
||||
x = self.agg(x) # [N,L,C] -> [N,C]
|
||||
x = self.dec(x) # [N,C] -> [N, N_class]
|
||||
return x
|
||||
|
||||
|
||||
@ -75,7 +76,7 @@ model_args = {
|
||||
'num_classes': n_classes,
|
||||
'vocab_size': vocab_size
|
||||
}
|
||||
model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size)
|
||||
model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
|
||||
|
||||
# train model
|
||||
train_args = {
|
||||
|
@ -13,6 +13,7 @@ from fastNLP.loader.dataset_loader import ClassDatasetLoader
|
||||
from fastNLP.models.base_model import BaseModel
|
||||
from fastNLP.modules import aggregation
|
||||
from fastNLP.modules import encoder
|
||||
from fastNLP.modules import decoder
|
||||
|
||||
|
||||
class ClassificationModel(BaseModel):
|
||||
@ -20,20 +21,20 @@ class ClassificationModel(BaseModel):
|
||||
Simple text classification model based on CNN.
|
||||
"""
|
||||
|
||||
def __init__(self, class_num, vocab_size):
|
||||
def __init__(self, num_classes, vocab_size):
|
||||
super(ClassificationModel, self).__init__()
|
||||
|
||||
self.embed = encoder.Embedding(nums=vocab_size, dims=300)
|
||||
self.conv = encoder.Conv(
|
||||
self.emb = encoder.Embedding(nums=vocab_size, dims=300)
|
||||
self.enc = encoder.Conv(
|
||||
in_channels=300, out_channels=100, kernel_size=3)
|
||||
self.pool = aggregation.MaxPool()
|
||||
self.output = encoder.Linear(input_size=100, output_size=class_num)
|
||||
self.agg = aggregation.MaxPool()
|
||||
self.dec = decoder.MLP(100, num_classes=num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.embed(x) # [N,L] -> [N,L,C]
|
||||
x = self.conv(x) # [N,L,C_in] -> [N,L,C_out]
|
||||
x = self.pool(x) # [N,L,C] -> [N,C]
|
||||
x = self.output(x) # [N,C] -> [N, N_class]
|
||||
x = self.emb(x) # [N,L] -> [N,L,C]
|
||||
x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
|
||||
x = self.agg(x) # [N,L,C] -> [N,C]
|
||||
x = self.dec(x) # [N,C] -> [N, N_class]
|
||||
return x
|
||||
|
||||
|
||||
@ -55,7 +56,7 @@ model_args = {
|
||||
'num_classes': n_classes,
|
||||
'vocab_size': vocab_size
|
||||
}
|
||||
model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size)
|
||||
model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
|
||||
|
||||
# train model
|
||||
train_args = {
|
||||
@ -75,4 +76,4 @@ trainer.cross_validate(model)
|
||||
# predict using model
|
||||
data_infer = [x[0] for x in data]
|
||||
infer = ClassificationInfer(data_dir)
|
||||
labels_pred = infer.predict(model, data_infer)
|
||||
labels_pred = infer.predict(model, data_infer)
|
Loading…
Reference in New Issue
Block a user