add cws train script and corresponding config file

This commit is contained in:
FengZiYjun 2018-07-30 09:52:46 +08:00
parent 233e8328f7
commit ef8ec3b9e4
2 changed files with 139 additions and 0 deletions

View File

@ -0,0 +1,29 @@
[train]
epochs = 2
batch_size = 32
pickle_path = "./save/"
validate = true
save_best_dev = true
model_saved_path = "./save/"
rnn_hidden_units = 100
rnn_layers = 2
rnn_bi_direction = true
word_emb_dim = 100
dropout = 0.5
use_crf = true
use_cuda = true
[test]
save_output = true
validate_in_training = true
save_dev_input = false
save_loss = true
batch_size = 64
pickle_path = "./save/"
rnn_hidden_units = 100
rnn_layers = 1
rnn_bi_direction = true
word_emb_dim = 100
dropout = 0.5
use_crf = true
use_cuda = true

View File

@ -0,0 +1,110 @@
import sys
sys.path.append("..")
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import POSTrainer
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import POSTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import Inference
data_name = "pku_training.utf8"
cws_data_path = "/home/zyfeng/data/pku_training.utf8"
pickle_path = "./save/"
data_infer_path = "data_for_tests/people_infer.txt"
def infer():
# Load infer configuration, the same as test
test_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
# fetch dictionary size and number of labels from pickle files
word2index = load_pickle(pickle_path, "word2id.pkl")
test_args["vocab_size"] = len(word2index)
index2label = load_pickle(pickle_path, "id2class.pkl")
test_args["num_classes"] = len(index2label)
# Define the same model
model = SeqLabeling(test_args)
# Dump trained parameters into the model
ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
print("model loaded!")
# Data Loader
raw_data_loader = BaseLoader(data_name, data_infer_path)
infer_data = raw_data_loader.load_lines()
# Inference interface
infer = Inference(pickle_path)
results = infer.predict(model, infer_data)
print(results)
print("Inference finished!")
def train():
# Config Loader
train_args = ConfigSection()
test_args = ConfigSection()
ConfigLoader("good_name", "good_path").load_config("./cws.cfg", {"train": train_args, "test": test_args})
# Data Loader
loader = TokenizeDatasetLoader(data_name, cws_data_path)
train_data = loader.load_pku()
# Preprocessor
p = POSPreprocess(train_data, pickle_path)
train_args["vocab_size"] = p.vocab_size
train_args["num_classes"] = p.num_classes
# Trainer
trainer = POSTrainer(train_args)
# Model
model = SeqLabeling(train_args)
# Start training
trainer.train(model)
print("Training finished!")
# Saver
saver = ModelSaver("./save/saved_model.pkl")
saver.save_pytorch(model)
print("Model saved!")
def test():
# Config Loader
train_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
# Define the same model
model = SeqLabeling(train_args)
# Dump trained parameters into the model
ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
print("model loaded!")
# Load test configuration
test_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
# Tester
tester = POSTester(test_args)
# Start testing
tester.test(model)
# print test results
print(tester.show_matrices())
print("model tested!")
if __name__ == "__main__":
train()