mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 11:17:50 +08:00
Merge pull request #32 from choosewhatulike/master
add a new chinese word segmentation model
This commit is contained in:
commit
2f83010d9d
34
reproduction/chinese_word_segment/cws.cfg
Normal file
34
reproduction/chinese_word_segment/cws.cfg
Normal file
@ -0,0 +1,34 @@
|
||||
[train]
|
||||
epochs = 30
|
||||
batch_size = 64
|
||||
pickle_path = "./save/"
|
||||
validate = true
|
||||
save_best_dev = true
|
||||
model_saved_path = "./save/"
|
||||
rnn_hidden_units = 100
|
||||
word_emb_dim = 100
|
||||
use_crf = true
|
||||
use_cuda = true
|
||||
|
||||
[test]
|
||||
save_output = true
|
||||
validate_in_training = true
|
||||
save_dev_input = false
|
||||
save_loss = true
|
||||
batch_size = 640
|
||||
pickle_path = "./save/"
|
||||
use_crf = true
|
||||
use_cuda = true
|
||||
|
||||
|
||||
[POS_test]
|
||||
save_output = true
|
||||
validate_in_training = true
|
||||
save_dev_input = false
|
||||
save_loss = true
|
||||
batch_size = 640
|
||||
pickle_path = "./save/"
|
||||
use_crf = true
|
||||
use_cuda = true
|
||||
rnn_hidden_units = 100
|
||||
word_emb_dim = 100
|
140
reproduction/chinese_word_segment/run.py
Normal file
140
reproduction/chinese_word_segment/run.py
Normal file
@ -0,0 +1,140 @@
|
||||
import sys, os
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
|
||||
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.core.trainer import SeqLabelTrainer
|
||||
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
|
||||
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
from fastNLP.core.tester import SeqLabelTester
|
||||
from fastNLP.models.sequence_modeling import AdvSeqLabel
|
||||
from fastNLP.core.inference import SeqLabelInfer
|
||||
from fastNLP.core.optimizer import SGD
|
||||
|
||||
# not in the file's dir
|
||||
if len(os.path.dirname(__file__)) != 0:
|
||||
os.chdir(os.path.dirname(__file__))
|
||||
datadir = 'icwb2-data'
|
||||
cfgfile = 'cws.cfg'
|
||||
data_name = "pku_training.utf8"
|
||||
|
||||
cws_data_path = os.path.join(datadir, "training/pku_training.utf8")
|
||||
pickle_path = "save"
|
||||
data_infer_path = os.path.join(datadir, "infer.utf8")
|
||||
|
||||
def infer():
|
||||
# Config Loader
|
||||
test_args = ConfigSection()
|
||||
ConfigLoader("config", "").load_config(cfgfile, {"POS_test": test_args})
|
||||
|
||||
# fetch dictionary size and number of labels from pickle files
|
||||
word2index = load_pickle(pickle_path, "word2id.pkl")
|
||||
test_args["vocab_size"] = len(word2index)
|
||||
index2label = load_pickle(pickle_path, "id2class.pkl")
|
||||
test_args["num_classes"] = len(index2label)
|
||||
|
||||
|
||||
# Define the same model
|
||||
model = AdvSeqLabel(test_args)
|
||||
|
||||
try:
|
||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
|
||||
print('model loaded!')
|
||||
except Exception as e:
|
||||
print('cannot load model!')
|
||||
raise
|
||||
|
||||
# Data Loader
|
||||
raw_data_loader = BaseLoader(data_name, data_infer_path)
|
||||
infer_data = raw_data_loader.load_lines()
|
||||
print('data loaded')
|
||||
|
||||
# Inference interface
|
||||
infer = SeqLabelInfer(pickle_path)
|
||||
results = infer.predict(model, infer_data)
|
||||
|
||||
print(results)
|
||||
print("Inference finished!")
|
||||
|
||||
|
||||
def train():
|
||||
# Config Loader
|
||||
train_args = ConfigSection()
|
||||
test_args = ConfigSection()
|
||||
ConfigLoader("good_name", "good_path").load_config(cfgfile, {"train": train_args, "test": test_args})
|
||||
|
||||
# Data Loader
|
||||
loader = TokenizeDatasetLoader(data_name, cws_data_path)
|
||||
train_data = loader.load_pku()
|
||||
|
||||
# Preprocessor
|
||||
p = POSPreprocess(train_data, pickle_path, train_dev_split=0.3)
|
||||
train_args["vocab_size"] = p.vocab_size
|
||||
train_args["num_classes"] = p.num_classes
|
||||
|
||||
# Trainer
|
||||
trainer = SeqLabelTrainer(train_args)
|
||||
|
||||
# Model
|
||||
model = AdvSeqLabel(train_args)
|
||||
try:
|
||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
|
||||
print('model parameter loaded!')
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# Start training
|
||||
trainer.train(model)
|
||||
print("Training finished!")
|
||||
|
||||
# Saver
|
||||
saver = ModelSaver("./save/saved_model.pkl")
|
||||
saver.save_pytorch(model)
|
||||
print("Model saved!")
|
||||
|
||||
|
||||
def test():
|
||||
# Config Loader
|
||||
test_args = ConfigSection()
|
||||
ConfigLoader("config", "").load_config(cfgfile, {"POS_test": test_args})
|
||||
|
||||
# fetch dictionary size and number of labels from pickle files
|
||||
word2index = load_pickle(pickle_path, "word2id.pkl")
|
||||
test_args["vocab_size"] = len(word2index)
|
||||
index2label = load_pickle(pickle_path, "id2class.pkl")
|
||||
test_args["num_classes"] = len(index2label)
|
||||
|
||||
# Define the same model
|
||||
model = AdvSeqLabel(test_args)
|
||||
|
||||
# Dump trained parameters into the model
|
||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
|
||||
print("model loaded!")
|
||||
|
||||
# Tester
|
||||
tester = SeqLabelTester(test_args)
|
||||
|
||||
# Start testing
|
||||
tester.test(model)
|
||||
|
||||
# print test results
|
||||
print(tester.show_matrices())
|
||||
print("model tested!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
|
||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer'])
|
||||
args = parser.parse_args()
|
||||
if args.mode == 'train':
|
||||
train()
|
||||
elif args.mode == 'test':
|
||||
test()
|
||||
elif args.mode == 'infer':
|
||||
infer()
|
||||
else:
|
||||
print('no mode specified for model!')
|
||||
parser.print_help()
|
Loading…
Reference in New Issue
Block a user