mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 11:17:50 +08:00
Prepare for CWS service:
- specify the name of the config file and the name of corresponding section where model init params store. - fastnlp.py needs load_pickle to get dictionary size and the number of labels - other minor adjustments
This commit is contained in:
parent
625b72691b
commit
9d6b0daa99
@ -19,13 +19,13 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
|
||||
def save_pickle(obj, pickle_path, file_name):
|
||||
with open(os.path.join(pickle_path, file_name), "wb") as f:
|
||||
_pickle.dump(obj, f)
|
||||
print("{} saved. ".format(file_name))
|
||||
print("{} saved in {}.".format(file_name, pickle_path))
|
||||
|
||||
|
||||
def load_pickle(pickle_path, file_name):
|
||||
with open(os.path.join(pickle_path, file_name), "rb") as f:
|
||||
obj = _pickle.load(f)
|
||||
print("{} loaded. ".format(file_name))
|
||||
print("{} loaded from {}.".format(file_name, pickle_path))
|
||||
return obj
|
||||
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
|
||||
# from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
|
||||
from fastNLP.core.preprocess import load_pickle
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
|
||||
@ -11,7 +12,9 @@ Example:
|
||||
"url": "www.fudan.edu.cn",
|
||||
"class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/
|
||||
"pickle": "seq_label_model.pkl",
|
||||
"type": "seq_label"
|
||||
"type": "seq_label",
|
||||
"config_file_name": "config", # the name of the config file which stores model initialization parameters
|
||||
"config_section_name": "text_class_model" # the name of the section in the config file which stores model init params
|
||||
},
|
||||
"text_class_model": {
|
||||
"url": "www.fudan.edu.cn",
|
||||
@ -25,13 +28,12 @@ FastNLP_MODEL_COLLECTION = {
|
||||
"url": "",
|
||||
"class": "sequence_modeling.AdvSeqLabel",
|
||||
"pickle": "cws_basic_model_v_0.pkl",
|
||||
"type": "seq_label"
|
||||
"type": "seq_label",
|
||||
"config_file_name": "config",
|
||||
"config_section_name": "text_class_model"
|
||||
}
|
||||
}
|
||||
|
||||
CONFIG_FILE_NAME = "config"
|
||||
SECTION_NAME = "text_class_model"
|
||||
|
||||
|
||||
class FastNLP(object):
|
||||
"""
|
||||
@ -56,10 +58,13 @@ class FastNLP(object):
|
||||
self.model = None
|
||||
self.infer_type = None # "seq_label"/"text_class"
|
||||
|
||||
def load(self, model_name):
|
||||
def load(self, model_name, config_file="config", section_name="model"):
|
||||
"""
|
||||
Load a pre-trained FastNLP model together with additional data.
|
||||
:param model_name: str, the name of a FastNLP model.
|
||||
:param config_file: str, the name of the config file which stores the initialization information of the model.
|
||||
(default: "config")
|
||||
:param section_name: str, the name of the corresponding section in the config file. (default: model)
|
||||
"""
|
||||
assert type(model_name) is str
|
||||
if model_name not in FastNLP_MODEL_COLLECTION:
|
||||
@ -71,7 +76,13 @@ class FastNLP(object):
|
||||
model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])
|
||||
|
||||
model_args = ConfigSection()
|
||||
ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args})
|
||||
ConfigLoader.load_config(self.model_dir + config_file, {section_name: model_args})
|
||||
|
||||
# fetch dictionary size and number of labels from pickle files
|
||||
word2index = load_pickle(self.model_dir, "word2id.pkl")
|
||||
model_args["vocab_size"] = len(word2index)
|
||||
index2label = load_pickle(self.model_dir, "id2class.pkl")
|
||||
model_args["num_classes"] = len(index2label)
|
||||
|
||||
# Construct the model
|
||||
model = model_class(model_args)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import sys, os
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
|
||||
|
||||
@ -11,7 +12,6 @@ from fastNLP.loader.model_loader import ModelLoader
|
||||
from fastNLP.core.tester import SeqLabelTester
|
||||
from fastNLP.models.sequence_modeling import AdvSeqLabel
|
||||
from fastNLP.core.inference import SeqLabelInfer
|
||||
from fastNLP.core.optimizer import SGD
|
||||
|
||||
# not in the file's dir
|
||||
if len(os.path.dirname(__file__)) != 0:
|
||||
@ -75,7 +75,7 @@ def train():
|
||||
train_args["num_classes"] = p.num_classes
|
||||
|
||||
# Trainer
|
||||
trainer = SeqLabelTrainer(train_args)
|
||||
trainer = SeqLabelTrainer(**train_args.data)
|
||||
|
||||
# Model
|
||||
model = AdvSeqLabel(train_args)
|
||||
|
@ -1,9 +1,13 @@
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from fastNLP.fastnlp import FastNLP
|
||||
|
||||
PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/data/save/"
|
||||
|
||||
def word_seg():
|
||||
nlp = FastNLP("./data_for_tests/")
|
||||
nlp.load("seq_label_model")
|
||||
nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES)
|
||||
nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test")
|
||||
text = "这是最好的基于深度学习的中文分词系统。"
|
||||
result = nlp.run(text)
|
||||
print(result)
|
||||
@ -20,4 +24,4 @@ def text_class():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text_class()
|
||||
word_seg()
|
||||
|
Loading…
Reference in New Issue
Block a user