diff --git a/reproduction/chinese_word_seg/cws.cfg b/reproduction/chinese_word_seg/cws.cfg new file mode 100644 index 00000000..ded4f623 --- /dev/null +++ b/reproduction/chinese_word_seg/cws.cfg @@ -0,0 +1,29 @@ +[train] +epochs = 2 +batch_size = 32 +pickle_path = "./save/" +validate = true +save_best_dev = true +model_saved_path = "./save/" +rnn_hidden_units = 100 +rnn_layers = 2 +rnn_bi_direction = true +word_emb_dim = 100 +dropout = 0.5 +use_crf = true +use_cuda = true + +[test] +save_output = true +validate_in_training = true +save_dev_input = false +save_loss = true +batch_size = 64 +pickle_path = "./save/" +rnn_hidden_units = 100 +rnn_layers = 1 +rnn_bi_direction = true +word_emb_dim = 100 +dropout = 0.5 +use_crf = true +use_cuda = true diff --git a/reproduction/chinese_word_seg/cws_train.py b/reproduction/chinese_word_seg/cws_train.py new file mode 100644 index 00000000..6616ff5f --- /dev/null +++ b/reproduction/chinese_word_seg/cws_train.py @@ -0,0 +1,110 @@ +import sys + +sys.path.append("..") + +from fastNLP.loader.config_loader import ConfigLoader, ConfigSection +from fastNLP.core.trainer import POSTrainer +from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader +from fastNLP.loader.preprocess import POSPreprocess, load_pickle +from fastNLP.saver.model_saver import ModelSaver +from fastNLP.loader.model_loader import ModelLoader +from fastNLP.core.tester import POSTester +from fastNLP.models.sequence_modeling import SeqLabeling +from fastNLP.core.inference import Inference + +data_name = "pku_training.utf8" +cws_data_path = "/home/zyfeng/data/pku_training.utf8" +pickle_path = "./save/" +data_infer_path = "data_for_tests/people_infer.txt" + + +def infer(): + # Load infer configuration, the same as test + test_args = ConfigSection() + ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) + + # fetch dictionary size and number of labels from pickle files + word2index = load_pickle(pickle_path, "word2id.pkl") + test_args["vocab_size"] = len(word2index) + index2label = load_pickle(pickle_path, "id2class.pkl") + test_args["num_classes"] = len(index2label) + + # Define the same model + model = SeqLabeling(test_args) + + # Dump trained parameters into the model + ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") + print("model loaded!") + + # Data Loader + raw_data_loader = BaseLoader(data_name, data_infer_path) + infer_data = raw_data_loader.load_lines() + + # Inference interface + infer = Inference(pickle_path) + results = infer.predict(model, infer_data) + + print(results) + print("Inference finished!") + + +def train(): + # Config Loader + train_args = ConfigSection() + test_args = ConfigSection() + ConfigLoader("good_name", "good_path").load_config("./cws.cfg", {"train": train_args, "test": test_args}) + + # Data Loader + loader = TokenizeDatasetLoader(data_name, cws_data_path) + train_data = loader.load_pku() + + # Preprocessor + p = POSPreprocess(train_data, pickle_path) + train_args["vocab_size"] = p.vocab_size + train_args["num_classes"] = p.num_classes + + # Trainer + trainer = POSTrainer(train_args) + + # Model + model = SeqLabeling(train_args) + + # Start training + trainer.train(model) + print("Training finished!") + + # Saver + saver = ModelSaver("./save/saved_model.pkl") + saver.save_pytorch(model) + print("Model saved!") + + +def test(): + # Config Loader + train_args = ConfigSection() + ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) + + # Define the same model + model = SeqLabeling(train_args) + + # Dump trained parameters into the model + ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") + print("model loaded!") + + # Load test configuration + test_args = ConfigSection() + ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) + + # Tester + tester = POSTester(test_args) + + # Start testing + tester.test(model) + + # print test results + print(tester.show_matrices()) + print("model tested!") + + +if __name__ == "__main__": + train()