fastNLP/examples/readme_example.py
2018-09-14 14:19:39 +08:00

76 lines
2.3 KiB
Python

from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.predictor import ClassificationInfer
from fastNLP.core.preprocess import ClassPreprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import aggregation
from fastNLP.modules import decoder
from fastNLP.modules import encoder
class ClassificationModel(BaseModel):
    """Simple text classification model based on CNN.

    The forward pass chains four stages: embedding lookup, convolution,
    max-pooling aggregation and an MLP decoder.

    Args:
        num_classes: Width of the final MLP layer (one output per class).
        vocab_size: Passed as ``nums`` to ``encoder.Embedding``.
        embed_dim: Embedding width; also used as the convolution's
            ``in_channels`` so the two always agree. Defaults to 300.
        num_filters: Convolution ``out_channels``; also used as the first
            MLP layer size so the two always agree. Defaults to 100.
        kernel_size: Convolution kernel size. Defaults to 3.
    """

    def __init__(self, num_classes, vocab_size, embed_dim=300,
                 num_filters=100, kernel_size=3):
        super(ClassificationModel, self).__init__()
        self.emb = encoder.Embedding(nums=vocab_size, dims=embed_dim)
        self.enc = encoder.Conv(
            in_channels=embed_dim,
            out_channels=num_filters,
            kernel_size=kernel_size,
        )
        self.agg = aggregation.MaxPool()
        self.dec = decoder.MLP(size_layer=[num_filters, num_classes])

    def forward(self, x):
        """Run ``x`` through embedding, convolution, pooling and the decoder.

        The shape notes on each line are the original author's annotations;
        presumably N = batch size, L = sequence length, C = channels --
        confirm against the fastNLP module implementations.

        Args:
            x: Input passed to the embedding layer.

        Returns:
            The output of the MLP decoder.
        """
        x = self.emb(x)  # [N,L] -> [N,L,C]
        x = self.enc(x)  # [N,L,C_in] -> [N,L,C_out]
        x = self.agg(x)  # [N,L,C] -> [N,C]
        x = self.dec(x)  # [N,C] -> [N, N_class]
        return x
# End-to-end example: load a text classification dataset, pre-process it,
# train ClassificationModel, then run inference on the same samples.

data_dir = 'save/'  # directory to save data and model
train_path = './data_for_tests/text_classify.txt'  # training set file

# load dataset
ds_loader = ClassDatasetLoader(train_path)
data = ds_loader.load()

# pre-process dataset
pre = ClassPreprocess()
train_set, dev_set = pre.run(data, train_dev_split=0.3, pickle_path=data_dir)
n_classes, vocab_size = pre.num_classes, pre.vocab_size

# construct model
model_args = {
    'num_classes': n_classes,
    'vocab_size': vocab_size,
}
# Build the model from model_args (same pattern as the trainer below) so the
# dict is the single source of truth instead of an unused duplicate.
model = ClassificationModel(**model_args)

# construct trainer
train_args = {
    "epochs": 3,
    "batch_size": 16,
    "pickle_path": data_dir,
    "validate": False,
    "save_best_dev": False,
    "model_saved_path": None,
    "use_cuda": True,
    "loss": Loss("cross_entropy"),
    "optimizer": Optimizer("Adam", lr=0.001),
}
trainer = ClassificationTrainer(**train_args)

# start training
trainer.train(model, train_data=train_set, dev_data=dev_set)

# predict using model
# Keep only the first field of every loaded sample (presumably the text,
# with the label in another field -- confirm against ClassDatasetLoader).
data_infer = [x[0] for x in data]
infer = ClassificationInfer(data_dir)
# model.cpu() is called before prediction, so inference uses the CPU copy.
labels_pred = infer.predict(model.cpu(), data_infer)
print(labels_pred)