Prepare for CWS service:

- specify the name of the config file and the name of corresponding section where model init params store.
- fastnlp.py needs load_pickle to get dictionary size and the number of labels
- other minor adjustments
This commit is contained in:
FengZiYjun 2018-08-30 11:45:47 +08:00
parent 625b72691b
commit 9d6b0daa99
4 changed files with 31 additions and 16 deletions

View File

@ -19,13 +19,13 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
def save_pickle(obj, pickle_path, file_name):
with open(os.path.join(pickle_path, file_name), "wb") as f:
_pickle.dump(obj, f)
print("{} saved. ".format(file_name))
print("{} saved in {}.".format(file_name, pickle_path))
def load_pickle(pickle_path, file_name):
with open(os.path.join(pickle_path, file_name), "rb") as f:
obj = _pickle.load(f)
print("{} loaded. ".format(file_name))
print("{} loaded from {}.".format(file_name, pickle_path))
return obj

View File

@ -1,4 +1,5 @@
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
# from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
from fastNLP.core.preprocess import load_pickle
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
@ -11,7 +12,9 @@ Example:
"url": "www.fudan.edu.cn",
"class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/
"pickle": "seq_label_model.pkl",
"type": "seq_label"
"type": "seq_label",
"config_file_name": "config", # the name of the config file which stores model initialization parameters
"config_section_name": "text_class_model" # the name of the section in the config file which stores model init params
},
"text_class_model": {
"url": "www.fudan.edu.cn",
@ -25,13 +28,12 @@ FastNLP_MODEL_COLLECTION = {
"url": "",
"class": "sequence_modeling.AdvSeqLabel",
"pickle": "cws_basic_model_v_0.pkl",
"type": "seq_label"
"type": "seq_label",
"config_file_name": "config",
"config_section_name": "text_class_model"
}
}
CONFIG_FILE_NAME = "config"
SECTION_NAME = "text_class_model"
class FastNLP(object):
"""
@ -56,10 +58,13 @@ class FastNLP(object):
self.model = None
self.infer_type = None # "seq_label"/"text_class"
def load(self, model_name):
def load(self, model_name, config_file="config", section_name="model"):
"""
Load a pre-trained FastNLP model together with additional data.
:param model_name: str, the name of a FastNLP model.
:param config_file: str, the name of the config file which stores the initialization information of the model.
(default: "config")
:param section_name: str, the name of the corresponding section in the config file. (default: model)
"""
assert type(model_name) is str
if model_name not in FastNLP_MODEL_COLLECTION:
@ -71,7 +76,13 @@ class FastNLP(object):
model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])
model_args = ConfigSection()
ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args})
ConfigLoader.load_config(self.model_dir + config_file, {section_name: model_args})
# fetch dictionary size and number of labels from pickle files
word2index = load_pickle(self.model_dir, "word2id.pkl")
model_args["vocab_size"] = len(word2index)
index2label = load_pickle(self.model_dir, "id2class.pkl")
model_args["num_classes"] = len(index2label)
# Construct the model
model = model_class(model_args)

View File

@ -1,4 +1,5 @@
import sys, os
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
@ -11,7 +12,6 @@ from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.core.inference import SeqLabelInfer
from fastNLP.core.optimizer import SGD
# not in the file's dir
if len(os.path.dirname(__file__)) != 0:
@ -75,7 +75,7 @@ def train():
train_args["num_classes"] = p.num_classes
# Trainer
trainer = SeqLabelTrainer(train_args)
trainer = SeqLabelTrainer(**train_args.data)
# Model
model = AdvSeqLabel(train_args)

View File

@ -1,9 +1,13 @@
import sys
sys.path.append("..")
from fastNLP.fastnlp import FastNLP
PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/data/save/"
def word_seg():
nlp = FastNLP("./data_for_tests/")
nlp.load("seq_label_model")
nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES)
nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test")
text = "这是最好的基于深度学习的中文分词系统。"
result = nlp.run(text)
print(result)
@ -20,4 +24,4 @@ def text_class():
if __name__ == "__main__":
text_class()
word_seg()