mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
* delete readme_example.py because it is oooooooout of date.
* rename preprocess.py into utils.py, because nothing about preprocess in it * anything in loader/ and saver/ is moved directly into io/ * corresponding unit tests are moved to /test/io * delete fastnlp.py, because we have new and better APIs * rename Biaffine_parser/run_test.py to Biaffine_parser/main.py; Otherwise, test will fail. * A looooooooooot of ancient codes to be refined...........
This commit is contained in:
parent
b6a0d33cb1
commit
e9d7074ba1
@ -1,75 +0,0 @@
|
||||
from fastNLP.core.loss import Loss
|
||||
from fastNLP.core.optimizer import Optimizer
|
||||
from fastNLP.core.predictor import ClassificationInfer
|
||||
from fastNLP.core.preprocess import ClassPreprocess
|
||||
from fastNLP.core.trainer import ClassificationTrainer
|
||||
from fastNLP.loader.dataset_loader import ClassDataSetLoader
|
||||
from fastNLP.models.base_model import BaseModel
|
||||
from fastNLP.modules import aggregator
|
||||
from fastNLP.modules import decoder
|
||||
from fastNLP.modules import encoder
|
||||
|
||||
|
||||
class ClassificationModel(BaseModel):
|
||||
"""
|
||||
Simple text classification model based on CNN.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes, vocab_size):
|
||||
super(ClassificationModel, self).__init__()
|
||||
|
||||
self.emb = encoder.Embedding(nums=vocab_size, dims=300)
|
||||
self.enc = encoder.Conv(
|
||||
in_channels=300, out_channels=100, kernel_size=3)
|
||||
self.agg = aggregator.MaxPool()
|
||||
self.dec = decoder.MLP(size_layer=[100, num_classes])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.emb(x) # [N,L] -> [N,L,C]
|
||||
x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
|
||||
x = self.agg(x) # [N,L,C] -> [N,C]
|
||||
x = self.dec(x) # [N,C] -> [N, N_class]
|
||||
return x
|
||||
|
||||
|
||||
data_dir = 'save/' # directory to save data and model
|
||||
train_path = './data_for_tests/text_classify.txt' # training set file
|
||||
|
||||
# load dataset
|
||||
ds_loader = ClassDataSetLoader()
|
||||
data = ds_loader.load()
|
||||
|
||||
# pre-process dataset
|
||||
pre = ClassPreprocess()
|
||||
train_set, dev_set = pre.run(data, train_dev_split=0.3, pickle_path=data_dir)
|
||||
n_classes, vocab_size = pre.num_classes, pre.vocab_size
|
||||
|
||||
# construct model
|
||||
model_args = {
|
||||
'num_classes': n_classes,
|
||||
'vocab_size': vocab_size
|
||||
}
|
||||
model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
|
||||
|
||||
# construct trainer
|
||||
train_args = {
|
||||
"epochs": 3,
|
||||
"batch_size": 16,
|
||||
"pickle_path": data_dir,
|
||||
"validate": False,
|
||||
"save_best_dev": False,
|
||||
"model_saved_path": None,
|
||||
"use_cuda": True,
|
||||
"loss": Loss("cross_entropy"),
|
||||
"optimizer": Optimizer("Adam", lr=0.001)
|
||||
}
|
||||
trainer = ClassificationTrainer(**train_args)
|
||||
|
||||
# start training
|
||||
trainer.train(model, train_data=train_set, dev_data=dev_set)
|
||||
|
||||
# predict using model
|
||||
data_infer = [x[0] for x in data]
|
||||
infer = ClassificationInfer(data_dir)
|
||||
labels_pred = infer.predict(model.cpu(), data_infer)
|
||||
print(labels_pred)
|
@ -1,5 +1,7 @@
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
import os
|
||||
|
||||
@ -17,7 +19,6 @@ from fastNLP.api.pipeline import Pipeline
|
||||
from fastNLP.core.metrics import SeqLabelEvaluator2
|
||||
from fastNLP.core.tester import Tester
|
||||
|
||||
|
||||
model_urls = {
|
||||
}
|
||||
|
||||
@ -228,7 +229,7 @@ class Parser(API):
|
||||
elif p.field_name == 'pos_list':
|
||||
p.field_name = 'gold_pos'
|
||||
pp(ds)
|
||||
head_cor, label_cor, total = 0,0,0
|
||||
head_cor, label_cor, total = 0, 0, 0
|
||||
for ins in ds:
|
||||
head_gold = ins['gold_heads']
|
||||
head_pred = ins['heads']
|
||||
@ -236,7 +237,7 @@ class Parser(API):
|
||||
total += length
|
||||
for i in range(length):
|
||||
head_cor += 1 if head_pred[i] == head_gold[i] else 0
|
||||
uas = head_cor/total
|
||||
uas = head_cor / total
|
||||
print('uas:{:.2f}'.format(uas))
|
||||
|
||||
for p in pp:
|
||||
@ -247,25 +248,34 @@ class Parser(API):
|
||||
|
||||
return uas
|
||||
|
||||
if __name__ == "__main__":
|
||||
# pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl'
|
||||
pos = POS(device='cpu')
|
||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' ,
|
||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
'那么这款无人机到底有多厉害?']
|
||||
print(pos.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
|
||||
print(pos.predict(s))
|
||||
|
||||
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
|
||||
cws = CWS(device='cuda:0')
|
||||
s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' ,
|
||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
if __name__ == "__main__":
|
||||
# 以下路径在102
|
||||
"""
|
||||
pos_model_path = '/home/hyan/fastNLP_models/upload-demo/upload/pos_crf-5e26d3b0.pkl'
|
||||
pos = POS(model_path=pos_model_path, device='cpu')
|
||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
|
||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
'那么这款无人机到底有多厉害?']
|
||||
print(cws.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
|
||||
#print(pos.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
|
||||
print(pos.predict(s))
|
||||
"""
|
||||
|
||||
"""
|
||||
cws_model_path = '/home/hyan/fastNLP_models/upload-demo/upload/cws_crf-5a8a3e66.pkl'
|
||||
cws = CWS(model_path=cws_model_path, device='cuda:0')
|
||||
s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂',
|
||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
'那么这款无人机到底有多厉害?']
|
||||
#print(cws.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
|
||||
cws.predict(s)
|
||||
parser = Parser(device='cuda:0')
|
||||
print(parser.test('../../reproduction/Biaffine_parser/test.conll'))
|
||||
"""
|
||||
|
||||
parser_model_path = "/home/hyan/fastNLP_models/upload-demo/upload/parser-d57cd5fc.pkl"
|
||||
parser = Parser(model_path=parser_model_path, device='cuda:0')
|
||||
# print(parser.test('../../reproduction/Biaffine_parser/test.conll'))
|
||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
|
||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
'那么这款无人机到底有多厉害?']
|
||||
print(parser.predict(s))
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Field(object):
|
||||
@ -30,6 +29,7 @@ class Field(object):
|
||||
def __repr__(self):
|
||||
return self.content.__repr__()
|
||||
|
||||
|
||||
class TextField(Field):
|
||||
def __init__(self, text, is_target):
|
||||
"""
|
||||
@ -43,6 +43,7 @@ class LabelField(Field):
|
||||
"""The Field representing a single label. Can be a string or integer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, label, is_target=True):
|
||||
super(LabelField, self).__init__(label, is_target)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FieldArray(object):
|
||||
def __init__(self, name, content, padding_val=0, is_target=False, need_tensor=False):
|
||||
self.name = name
|
||||
@ -10,7 +10,7 @@ class FieldArray(object):
|
||||
self.need_tensor = need_tensor
|
||||
|
||||
def __repr__(self):
|
||||
#TODO
|
||||
# TODO
|
||||
return '{}: {}'.format(self.name, self.content.__repr__())
|
||||
|
||||
def append(self, val):
|
||||
|
@ -50,20 +50,6 @@ class Predictor(object):
|
||||
return y
|
||||
|
||||
|
||||
class SeqLabelInfer(Predictor):
|
||||
def __init__(self, pickle_path):
|
||||
print(
|
||||
"[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.")
|
||||
super(SeqLabelInfer, self).__init__()
|
||||
|
||||
|
||||
class ClassificationInfer(Predictor):
|
||||
def __init__(self, pickle_path):
|
||||
print(
|
||||
"[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.")
|
||||
super(ClassificationInfer, self).__init__()
|
||||
|
||||
|
||||
def seq_label_post_processor(batch_outputs, label_vocab):
|
||||
results = []
|
||||
for batch in batch_outputs:
|
||||
|
@ -1,6 +1,8 @@
|
||||
from itertools import chain
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from itertools import chain
|
||||
|
||||
|
||||
def convert_to_torch_tensor(data_list, use_cuda):
|
||||
"""Convert lists into (cuda) Tensors.
|
||||
@ -43,6 +45,7 @@ class RandomSampler(BaseSampler):
|
||||
def __call__(self, data_set):
|
||||
return list(np.random.permutation(len(data_set)))
|
||||
|
||||
|
||||
class BucketSampler(BaseSampler):
|
||||
|
||||
def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_lens'):
|
||||
@ -56,14 +59,14 @@ class BucketSampler(BaseSampler):
|
||||
total_sample_num = len(seq_lens)
|
||||
|
||||
bucket_indexes = []
|
||||
num_sample_per_bucket = total_sample_num//self.num_buckets
|
||||
num_sample_per_bucket = total_sample_num // self.num_buckets
|
||||
for i in range(self.num_buckets):
|
||||
bucket_indexes.append([num_sample_per_bucket*i, num_sample_per_bucket*(i+1)])
|
||||
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])
|
||||
bucket_indexes[-1][1] = total_sample_num
|
||||
|
||||
sorted_seq_lens = list(sorted([(idx, seq_len) for
|
||||
idx, seq_len in zip(range(total_sample_num), seq_lens)],
|
||||
key=lambda x:x[1]))
|
||||
key=lambda x: x[1]))
|
||||
|
||||
batchs = []
|
||||
|
||||
@ -73,19 +76,18 @@ class BucketSampler(BaseSampler):
|
||||
end_idx = bucket_indexes[b_idx][1]
|
||||
sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
|
||||
left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
|
||||
num_batch_per_bucket = len(left_init_indexes)//self.batch_size
|
||||
num_batch_per_bucket = len(left_init_indexes) // self.batch_size
|
||||
np.random.shuffle(left_init_indexes)
|
||||
for i in range(num_batch_per_bucket):
|
||||
batchs.append(left_init_indexes[i*self.batch_size:(i+1)*self.batch_size])
|
||||
left_init_indexes = left_init_indexes[num_batch_per_bucket*self.batch_size:]
|
||||
if (left_init_indexes)!=0:
|
||||
batchs.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size])
|
||||
left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:]
|
||||
if (left_init_indexes) != 0:
|
||||
batchs.append(left_init_indexes)
|
||||
np.random.shuffle(batchs)
|
||||
|
||||
return list(chain(*batchs))
|
||||
|
||||
|
||||
|
||||
def simple_sort_bucketing(lengths):
|
||||
"""
|
||||
|
||||
@ -105,6 +107,7 @@ def simple_sort_bucketing(lengths):
|
||||
# TODO: need to return buckets
|
||||
return [idx for idx, _ in sorted_lengths]
|
||||
|
||||
|
||||
def k_means_1d(x, k, max_iter=100):
|
||||
"""Perform k-means on 1-D data.
|
||||
|
||||
@ -159,4 +162,3 @@ def k_means_bucketing(lengths, buckets):
|
||||
if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
|
||||
bucket_data[bucket_id].append(idx)
|
||||
return bucket_data
|
||||
|
||||
|
@ -1,10 +1,11 @@
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.metrics import Evaluator
|
||||
from fastNLP.core.sampler import RandomSampler
|
||||
from fastNLP.saver.logger import create_logger
|
||||
from fastNLP.io.logger import create_logger
|
||||
|
||||
logger = create_logger(__name__, "./train_test.log")
|
||||
|
||||
@ -119,24 +120,3 @@ class Tester(object):
|
||||
|
||||
"""
|
||||
return ", ".join([str(key) + "=" + str(value) for key, value in results.items()])
|
||||
|
||||
|
||||
class SeqLabelTester(Tester):
|
||||
def __init__(self, **test_args):
|
||||
print(
|
||||
"[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester directly.")
|
||||
super(SeqLabelTester, self).__init__(**test_args)
|
||||
|
||||
|
||||
class ClassificationTester(Tester):
|
||||
def __init__(self, **test_args):
|
||||
print(
|
||||
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.")
|
||||
super(ClassificationTester, self).__init__(**test_args)
|
||||
|
||||
|
||||
class SNLITester(Tester):
|
||||
def __init__(self, **test_args):
|
||||
print(
|
||||
"[FastNLP Warning] SNLITester will be deprecated. Please use Tester directly.")
|
||||
super(SNLITester, self).__init__(**test_args)
|
||||
|
@ -9,11 +9,10 @@ from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.loss import Loss
|
||||
from fastNLP.core.metrics import Evaluator
|
||||
from fastNLP.core.optimizer import Optimizer
|
||||
from fastNLP.core.sampler import BucketSampler
|
||||
from fastNLP.core.tester import SeqLabelTester, ClassificationTester, SNLITester
|
||||
from fastNLP.core.sampler import RandomSampler
|
||||
from fastNLP.core.tester import Tester
|
||||
from fastNLP.saver.logger import create_logger
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
from fastNLP.io.logger import create_logger
|
||||
from fastNLP.io.model_saver import ModelSaver
|
||||
|
||||
logger = create_logger(__name__, "./train_test.log")
|
||||
logger.disabled = True
|
||||
@ -182,19 +181,10 @@ class Trainer(object):
|
||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
|
||||
for name, param in self._model.named_parameters():
|
||||
if param.requires_grad:
|
||||
<<<<<<< HEAD
|
||||
# self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=step)
|
||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=step)
|
||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=step)
|
||||
pass
|
||||
|
||||
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
|
||||
=======
|
||||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
|
||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
|
||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
|
||||
if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0:
|
||||
>>>>>>> 5924fe0... fix and update tester, trainer, seq_model, add parser pipeline builder
|
||||
end = time.time()
|
||||
diff = timedelta(seconds=round(end - kwargs["start"]))
|
||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
|
||||
@ -339,40 +329,3 @@ class Trainer(object):
|
||||
def set_validator(self, validor):
|
||||
self.validator = validor
|
||||
|
||||
|
||||
class SeqLabelTrainer(Trainer):
|
||||
"""Trainer for Sequence Labeling
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
print(
|
||||
"[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer directly.")
|
||||
super(SeqLabelTrainer, self).__init__(**kwargs)
|
||||
|
||||
def _create_validator(self, valid_args):
|
||||
return SeqLabelTester(**valid_args)
|
||||
|
||||
|
||||
class ClassificationTrainer(Trainer):
|
||||
"""Trainer for text classification."""
|
||||
|
||||
def __init__(self, **train_args):
|
||||
print(
|
||||
"[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer directly.")
|
||||
super(ClassificationTrainer, self).__init__(**train_args)
|
||||
|
||||
def _create_validator(self, valid_args):
|
||||
return ClassificationTester(**valid_args)
|
||||
|
||||
|
||||
class SNLITrainer(Trainer):
|
||||
"""Trainer for text SNLI."""
|
||||
|
||||
def __init__(self, **train_args):
|
||||
print(
|
||||
"[FastNLP Warning] SNLITrainer will be deprecated. Please use Trainer directly.")
|
||||
super(SNLITrainer, self).__init__(**train_args)
|
||||
|
||||
def _create_validator(self, valid_args):
|
||||
return SNLITester(**valid_args)
|
||||
|
@ -2,8 +2,6 @@ import _pickle
|
||||
import os
|
||||
|
||||
|
||||
# the first vocab in dict with the index = 5
|
||||
|
||||
def save_pickle(obj, pickle_path, file_name):
|
||||
"""Save an object into a pickle file.
|
||||
|
@ -13,7 +13,7 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
|
||||
|
||||
def isiterable(p_object):
|
||||
try:
|
||||
it = iter(p_object)
|
||||
_ = iter(p_object)
|
||||
except TypeError:
|
||||
return False
|
||||
return True
|
||||
|
@ -1,343 +0,0 @@
|
||||
import os
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.loader.dataset_loader import convert_seq_dataset
|
||||
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
|
||||
from fastNLP.core.preprocess import load_pickle
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
|
||||
"""
|
||||
mapping from model name to [URL, file_name.class_name, model_pickle_name]
|
||||
Notice that the class of the model should be in "models" directory.
|
||||
|
||||
Example:
|
||||
"seq_label_model": {
|
||||
"url": "www.fudan.edu.cn",
|
||||
"class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/
|
||||
"pickle": "seq_label_model.pkl",
|
||||
"type": "seq_label",
|
||||
"config_file_name": "config", # the name of the config file which stores model initialization parameters
|
||||
"config_section_name": "text_class_model" # the name of the section in the config file which stores model init params
|
||||
},
|
||||
"text_class_model": {
|
||||
"url": "www.fudan.edu.cn",
|
||||
"class": "cnn_text_classification.CNNText",
|
||||
"pickle": "text_class_model.pkl",
|
||||
"type": "text_class"
|
||||
}
|
||||
"""
|
||||
FastNLP_MODEL_COLLECTION = {
|
||||
"cws_basic_model": {
|
||||
"url": "",
|
||||
"class": "sequence_modeling.AdvSeqLabel",
|
||||
"pickle": "cws_basic_model_v_0.pkl",
|
||||
"type": "seq_label",
|
||||
"config_file_name": "cws.cfg",
|
||||
"config_section_name": "text_class_model"
|
||||
},
|
||||
"pos_tag_model": {
|
||||
"url": "",
|
||||
"class": "sequence_modeling.AdvSeqLabel",
|
||||
"pickle": "pos_tag_model_v_0.pkl",
|
||||
"type": "seq_label",
|
||||
"config_file_name": "pos_tag.cfg",
|
||||
"config_section_name": "pos_tag_model"
|
||||
},
|
||||
"text_classify_model": {
|
||||
"url": "",
|
||||
"class": "cnn_text_classification.CNNText",
|
||||
"pickle": "text_class_model_v0.pkl",
|
||||
"type": "text_class",
|
||||
"config_file_name": "text_classify.cfg",
|
||||
"config_section_name": "model"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class FastNLP(object):
|
||||
"""
|
||||
High-level interface for direct model inference.
|
||||
Example Usage
|
||||
::
|
||||
fastnlp = FastNLP()
|
||||
fastnlp.load("zh_pos_tag_model")
|
||||
text = "这是最好的基于深度学习的中文分词系统。"
|
||||
result = fastnlp.run(text)
|
||||
print(result) # ["这", "是", "最好", "的", "基于", "深度学习", "的", "中文", "分词", "系统", "。"]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir="./"):
|
||||
"""
|
||||
:param model_dir: this directory should contain the following files:
|
||||
1. a trained model
|
||||
2. a config file, which is a fastNLP's configuration.
|
||||
3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs.
|
||||
"""
|
||||
self.model_dir = model_dir
|
||||
self.model = None
|
||||
self.infer_type = None # "seq_label"/"text_class"
|
||||
self.word_vocab = None
|
||||
self.label_vocab = None
|
||||
|
||||
def load(self, model_name, config_file="config", section_name="model"):
|
||||
"""
|
||||
Load a pre-trained FastNLP model together with additional data.
|
||||
:param model_name: str, the name of a FastNLP model.
|
||||
:param config_file: str, the name of the config file which stores the initialization information of the model.
|
||||
(default: "config")
|
||||
:param section_name: str, the name of the corresponding section in the config file. (default: model)
|
||||
"""
|
||||
assert type(model_name) is str
|
||||
if model_name not in FastNLP_MODEL_COLLECTION:
|
||||
raise ValueError("No FastNLP model named {}.".format(model_name))
|
||||
|
||||
if not self.model_exist(model_dir=self.model_dir):
|
||||
self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"])
|
||||
|
||||
model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])
|
||||
print("Restore model class {}".format(str(model_class)))
|
||||
|
||||
model_args = ConfigSection()
|
||||
ConfigLoader.load_config(os.path.join(self.model_dir, config_file), {section_name: model_args})
|
||||
print("Restore model hyper-parameters {}".format(str(model_args.data)))
|
||||
|
||||
# fetch dictionary size and number of labels from pickle files
|
||||
self.word_vocab = load_pickle(self.model_dir, "word2id.pkl")
|
||||
model_args["vocab_size"] = len(self.word_vocab)
|
||||
self.label_vocab = load_pickle(self.model_dir, "label2id.pkl")
|
||||
model_args["num_classes"] = len(self.label_vocab)
|
||||
|
||||
# Construct the model
|
||||
model = model_class(model_args)
|
||||
print("Model constructed.")
|
||||
|
||||
# To do: framework independent
|
||||
ModelLoader.load_pytorch(model, os.path.join(self.model_dir, FastNLP_MODEL_COLLECTION[model_name]["pickle"]))
|
||||
print("Model weights loaded.")
|
||||
|
||||
self.model = model
|
||||
self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"]
|
||||
|
||||
print("Inference ready.")
|
||||
|
||||
def run(self, raw_input):
|
||||
"""
|
||||
Perform inference over given input using the loaded model.
|
||||
:param raw_input: list of string. Each list is an input query.
|
||||
:return results:
|
||||
"""
|
||||
|
||||
infer = self._create_inference(self.model_dir)
|
||||
|
||||
# tokenize: list of string ---> 2-D list of string
|
||||
infer_input = self.tokenize(raw_input, language="zh")
|
||||
|
||||
# create DataSet: 2-D list of strings ----> DataSet
|
||||
infer_data = self._create_data_set(infer_input)
|
||||
|
||||
# DataSet ---> 2-D list of tags
|
||||
results = infer.predict(self.model, infer_data)
|
||||
|
||||
# 2-D list of tags ---> list of final answers
|
||||
outputs = self._make_output(results, infer_input)
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _get_model_class(file_class_name):
|
||||
"""
|
||||
Feature the class specified by <file_class_name>
|
||||
:param file_class_name: str, contains the name of the Python module followed by the name of the class.
|
||||
Example: "sequence_modeling.SeqLabeling"
|
||||
:return module: the model class
|
||||
"""
|
||||
import_prefix = "fastNLP.models."
|
||||
parts = (import_prefix + file_class_name).split(".")
|
||||
from_module = ".".join(parts[:-1])
|
||||
module = __import__(from_module)
|
||||
for sub in parts[1:]:
|
||||
module = getattr(module, sub)
|
||||
return module
|
||||
|
||||
def _create_inference(self, model_dir):
|
||||
"""Specify which task to perform.
|
||||
|
||||
:param model_dir:
|
||||
:return:
|
||||
"""
|
||||
if self.infer_type == "seq_label":
|
||||
return SeqLabelInfer(model_dir)
|
||||
elif self.infer_type == "text_class":
|
||||
return ClassificationInfer(model_dir)
|
||||
else:
|
||||
raise ValueError("fail to create inference instance")
|
||||
|
||||
def _create_data_set(self, infer_input):
|
||||
"""Create a DataSet object given the raw inputs.
|
||||
|
||||
:param infer_input: 2-D lists of strings
|
||||
:return data_set: a DataSet object
|
||||
"""
|
||||
if self.infer_type in ["seq_label", "text_class"]:
|
||||
data_set = convert_seq_dataset(infer_input)
|
||||
data_set.index_field("word_seq", self.word_vocab)
|
||||
if self.infer_type == "seq_label":
|
||||
data_set.set_origin_len("word_seq")
|
||||
return data_set
|
||||
else:
|
||||
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))
|
||||
|
||||
|
||||
def _load(self, model_dir, model_name):
|
||||
|
||||
return 0
|
||||
|
||||
def _download(self, model_name, url):
|
||||
"""
|
||||
Download the model weights from <url> and save in <self.model_dir>.
|
||||
:param model_name:
|
||||
:param url:
|
||||
"""
|
||||
print("Downloading {} from {}".format(model_name, url))
|
||||
# TODO: download model via url
|
||||
|
||||
def model_exist(self, model_dir):
|
||||
"""
|
||||
Check whether the desired model is already in the directory.
|
||||
:param model_dir:
|
||||
"""
|
||||
return True
|
||||
|
||||
def tokenize(self, text, language):
|
||||
"""Extract tokens from strings.
|
||||
For English, extract words separated by space.
|
||||
For Chinese, extract characters.
|
||||
TODO: more complex tokenization methods
|
||||
|
||||
:param text: list of string
|
||||
:param language: str, one of ('zh', 'en'), Chinese or English.
|
||||
:return data: list of list of string, each string is a token.
|
||||
"""
|
||||
assert language in ("zh", "en")
|
||||
data = []
|
||||
for sent in text:
|
||||
if language == "en":
|
||||
tokens = sent.strip().split()
|
||||
elif language == "zh":
|
||||
tokens = [char for char in sent]
|
||||
else:
|
||||
raise RuntimeError("Unknown language {}".format(language))
|
||||
data.append(tokens)
|
||||
return data
|
||||
|
||||
def _make_output(self, results, infer_input):
|
||||
"""Transform the infer output into user-friendly output.
|
||||
|
||||
:param results: 1 or 2-D list of strings.
|
||||
If self.infer_type == "seq_label", it is of shape [num_examples, tag_seq_length]
|
||||
If self.infer_type == "text_class", it is of shape [num_examples]
|
||||
:param infer_input: 2-D list of string, the input query before inference.
|
||||
:return outputs: list. Each entry is a prediction.
|
||||
"""
|
||||
if self.infer_type == "seq_label":
|
||||
outputs = make_seq_label_output(results, infer_input)
|
||||
elif self.infer_type == "text_class":
|
||||
outputs = make_class_output(results, infer_input)
|
||||
else:
|
||||
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))
|
||||
return outputs
|
||||
|
||||
|
||||
def make_seq_label_output(result, infer_input):
|
||||
"""Transform model output into user-friendly contents.
|
||||
|
||||
:param result: 2-D list of strings. (model output)
|
||||
:param infer_input: 2-D list of string (model input)
|
||||
:return ret: list of list of tuples
|
||||
[
|
||||
[(word_11, label_11), (word_12, label_12), ...],
|
||||
[(word_21, label_21), (word_22, label_22), ...],
|
||||
...
|
||||
]
|
||||
"""
|
||||
ret = []
|
||||
for example_x, example_y in zip(infer_input, result):
|
||||
ret.append([(x, y) for x, y in zip(example_x, example_y)])
|
||||
return ret
|
||||
|
||||
def make_class_output(result, infer_input):
|
||||
"""Transform model output into user-friendly contents.
|
||||
|
||||
:param result: 2-D list of strings. (model output)
|
||||
:param infer_input: 1-D list of string (model input)
|
||||
:return ret: the same as result, [label_1, label_2, ...]
|
||||
"""
|
||||
return result
|
||||
|
||||
|
||||
def interpret_word_seg_results(char_seq, label_seq):
|
||||
"""Transform model output into user-friendly contents.
|
||||
|
||||
Example: In CWS, convert <BMES> labeling into segmented text.
|
||||
:param char_seq: list of string,
|
||||
:param label_seq: list of string, the same length as char_seq
|
||||
Each entry is one of ('B', 'M', 'E', 'S').
|
||||
:return output: list of words
|
||||
"""
|
||||
words = []
|
||||
word = ""
|
||||
for char, label in zip(char_seq, label_seq):
|
||||
if label[0] == "B":
|
||||
if word != "":
|
||||
words.append(word)
|
||||
word = char
|
||||
elif label[0] == "M":
|
||||
word += char
|
||||
elif label[0] == "E":
|
||||
word += char
|
||||
words.append(word)
|
||||
word = ""
|
||||
elif label[0] == "S":
|
||||
if word != "":
|
||||
words.append(word)
|
||||
word = ""
|
||||
words.append(char)
|
||||
else:
|
||||
raise ValueError("invalid label {}".format(label[0]))
|
||||
return words
|
||||
|
||||
|
||||
def interpret_cws_pos_results(char_seq, label_seq):
|
||||
"""Transform model output into user-friendly contents.
|
||||
|
||||
:param char_seq: list of string
|
||||
:param label_seq: list of string, the same length as char_seq.
|
||||
:return outputs: list of tuple (words, pos_tag):
|
||||
"""
|
||||
|
||||
def pos_tag_check(seq):
|
||||
"""check whether all entries are the same """
|
||||
return len(set(seq)) <= 1
|
||||
|
||||
word = []
|
||||
word_pos = []
|
||||
outputs = []
|
||||
for char, label in zip(char_seq, label_seq):
|
||||
tmp = label.split("-")
|
||||
cws_label, pos_tag = tmp[0], tmp[1]
|
||||
|
||||
if cws_label == "B" or cws_label == "M":
|
||||
word.append(char)
|
||||
word_pos.append(pos_tag)
|
||||
elif cws_label == "E":
|
||||
word.append(char)
|
||||
word_pos.append(pos_tag)
|
||||
if not pos_tag_check(word_pos):
|
||||
raise RuntimeError("character-wise pos tags inconsistent. ")
|
||||
outputs.append(("".join(word), word_pos[0]))
|
||||
word.clear()
|
||||
word_pos.clear()
|
||||
elif cws_label == "S":
|
||||
outputs.append((char, pos_tag))
|
||||
return outputs
|
@ -2,7 +2,7 @@ import configparser
|
||||
import json
|
||||
import os
|
||||
|
||||
from fastNLP.loader.base_loader import BaseLoader
|
||||
from fastNLP.io.base_loader import BaseLoader
|
||||
|
||||
|
||||
class ConfigLoader(BaseLoader):
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
|
||||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
|
||||
from fastNLP.saver.logger import create_logger
|
||||
from fastNLP.io.config_loader import ConfigSection, ConfigLoader
|
||||
from fastNLP.io.logger import create_logger
|
||||
|
||||
|
||||
class ConfigSaver(object):
|
@ -3,7 +3,7 @@ import os
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.field import *
|
||||
from fastNLP.core.instance import Instance
|
||||
from fastNLP.loader.base_loader import BaseLoader
|
||||
from fastNLP.io.base_loader import BaseLoader
|
||||
|
||||
|
||||
def convert_seq_dataset(data):
|
@ -1,10 +1,7 @@
|
||||
import _pickle
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from fastNLP.loader.base_loader import BaseLoader
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.io.base_loader import BaseLoader
|
||||
|
||||
|
||||
class EmbedLoader(BaseLoader):
|
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
from fastNLP.loader.base_loader import BaseLoader
|
||||
from fastNLP.io.base_loader import BaseLoader
|
||||
|
||||
|
||||
class ModelLoader(BaseLoader):
|
||||
@ -19,10 +19,10 @@ class ModelLoader(BaseLoader):
|
||||
:param model_path: str, the path to the saved model.
|
||||
"""
|
||||
empty_model.load_state_dict(torch.load(model_path))
|
||||
|
||||
|
||||
@staticmethod
|
||||
def load_pytorch(model_path):
|
||||
def load_pytorch_model(model_path):
|
||||
"""Load the entire model.
|
||||
|
||||
"""
|
||||
return torch.load(model_path)
|
||||
return torch.load(model_path)
|
@ -1,13 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
class TimestepDropout(torch.nn.Dropout):
|
||||
"""This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single
|
||||
dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step.
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
dropout_mask = x.new_ones(x.shape[0], x.shape[-1])
|
||||
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True)
|
||||
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim]
|
||||
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim]
|
||||
if self.inplace:
|
||||
x *= dropout_mask
|
||||
return
|
||||
|
@ -1,13 +1,11 @@
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])
|
||||
|
||||
from fastNLP.api.processor import *
|
||||
from fastNLP.api.pipeline import Pipeline
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.models.biaffine_parser import BiaffineParser
|
||||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
|
||||
from fastNLP.io.config_loader import ConfigSection, ConfigLoader
|
||||
|
||||
import _pickle as pickle
|
||||
import torch
|
||||
|
@ -1,11 +1,9 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])
|
||||
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
|
||||
from fastNLP.core.dataset import DataSet
|
@ -3,8 +3,6 @@ import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
|
||||
|
||||
from collections import defaultdict
|
||||
import math
|
||||
import torch
|
||||
import re
|
||||
|
||||
@ -13,16 +11,13 @@ from fastNLP.core.metrics import Evaluator
|
||||
from fastNLP.core.instance import Instance
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.sampler import SequentialSampler
|
||||
from fastNLP.core.field import TextField, SeqLabelField
|
||||
from fastNLP.core.preprocess import load_pickle
|
||||
from fastNLP.core.tester import Tester
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
from fastNLP.loader.embed_loader import EmbedLoader
|
||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.io.model_loader import ModelLoader
|
||||
from fastNLP.io.embed_loader import EmbedLoader
|
||||
from fastNLP.models.biaffine_parser import BiaffineParser
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
from fastNLP.io.model_saver import ModelSaver
|
||||
|
||||
BOS = '<BOS>'
|
||||
EOS = '<EOS>'
|
||||
|
@ -1,10 +1,10 @@
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fastNLP.core.preprocess import ClassPreprocess as Preprocess
|
||||
from fastNLP.core.trainer import ClassificationTrainer
|
||||
from fastNLP.loader.config_loader import ConfigLoader
|
||||
from fastNLP.loader.config_loader import ConfigSection
|
||||
from fastNLP.loader.dataset_loader import ClassDataSetLoader as Dataset_loader
|
||||
from fastNLP.core.utils import ClassPreprocess as Preprocess
|
||||
from fastNLP.io.config_loader import ConfigLoader
|
||||
from fastNLP.io.config_loader import ConfigSection
|
||||
from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader
|
||||
from fastNLP.models.base_model import BaseModel
|
||||
from fastNLP.modules.aggregator.self_attention import SelfAttention
|
||||
from fastNLP.modules.decoder.MLP import MLP
|
||||
|
@ -1,8 +1,8 @@
|
||||
|
||||
|
||||
from fastNLP.loader.dataset_loader import DataSetLoader
|
||||
from fastNLP.core.instance import Instance
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.instance import Instance
|
||||
from fastNLP.io.dataset_loader import DataSetLoader
|
||||
|
||||
|
||||
def cut_long_sentence(sent, max_sample_length=200):
|
||||
|
@ -3,17 +3,16 @@ import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
|
||||
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.core.trainer import SeqLabelTrainer
|
||||
from fastNLP.loader.dataset_loader import BaseLoader, TokenizeDataSetLoader
|
||||
from fastNLP.core.preprocess import load_pickle
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader
|
||||
from fastNLP.core.utils import load_pickle
|
||||
from fastNLP.io.model_saver import ModelSaver
|
||||
from fastNLP.io.model_loader import ModelLoader
|
||||
from fastNLP.core.tester import SeqLabelTester
|
||||
from fastNLP.models.sequence_modeling import AdvSeqLabel
|
||||
from fastNLP.core.predictor import SeqLabelInfer
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.preprocess import save_pickle
|
||||
from fastNLP.core.utils import save_pickle
|
||||
from fastNLP.core.metrics import SeqLabelEvaluator
|
||||
|
||||
# not in the file's dir
|
||||
|
@ -13,8 +13,8 @@ from fastNLP.core.instance import Instance
|
||||
from fastNLP.core.metrics import SeqLabelEvaluator
|
||||
from fastNLP.core.optimizer import Optimizer
|
||||
from fastNLP.core.trainer import Trainer
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader
|
||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.io.dataset_loader import PeopleDailyCorpusLoader
|
||||
from fastNLP.models.sequence_modeling import AdvSeqLabel
|
||||
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
|
||||
from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq_dataset
|
||||
from fastNLP.io.dataset_loader import convert_seq2seq_dataset, convert_seq_dataset
|
||||
|
||||
|
||||
class TestDataSet(unittest.TestCase):
|
||||
|
@ -1,12 +1,10 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.predictor import Predictor
|
||||
from fastNLP.core.preprocess import save_pickle
|
||||
from fastNLP.core.utils import save_pickle
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.loader.base_loader import BaseLoader
|
||||
from fastNLP.loader.dataset_loader import convert_seq_dataset
|
||||
from fastNLP.io.dataset_loader import convert_seq_dataset
|
||||
from fastNLP.models.cnn_text_classification import CNNText
|
||||
from fastNLP.models.sequence_modeling import SeqLabeling
|
||||
|
||||
|
@ -3,7 +3,7 @@ import json
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
|
||||
from fastNLP.io.config_loader import ConfigSection, ConfigLoader
|
||||
|
||||
|
||||
class TestConfigLoader(unittest.TestCase):
|
@ -1,8 +1,8 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
|
||||
from fastNLP.saver.config_saver import ConfigSaver
|
||||
from fastNLP.io.config_loader import ConfigSection, ConfigLoader
|
||||
from fastNLP.io.config_saver import ConfigSaver
|
||||
|
||||
|
||||
class TestConfigSaver(unittest.TestCase):
|
@ -1,9 +1,9 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
|
||||
PeopleDailyCorpusLoader, ConllLoader
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.io.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
|
||||
PeopleDailyCorpusLoader, ConllLoader
|
||||
|
||||
|
||||
class TestDatasetLoader(unittest.TestCase):
|
||||
def test_case_1(self):
|
@ -1,10 +1,8 @@
|
||||
import unittest
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from fastNLP.loader.embed_loader import EmbedLoader
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.io.embed_loader import EmbedLoader
|
||||
|
||||
|
||||
class TestEmbedLoader(unittest.TestCase):
|
@ -3,17 +3,17 @@ import sys
|
||||
|
||||
sys.path.append("..")
|
||||
import argparse
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.loader.dataset_loader import BaseLoader
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.io.dataset_loader import BaseLoader
|
||||
from fastNLP.io.model_saver import ModelSaver
|
||||
from fastNLP.io.model_loader import ModelLoader
|
||||
from fastNLP.core.tester import SeqLabelTester
|
||||
from fastNLP.models.sequence_modeling import SeqLabeling
|
||||
from fastNLP.core.predictor import SeqLabelInfer
|
||||
from fastNLP.core.optimizer import Optimizer
|
||||
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
|
||||
from fastNLP.core.metrics import SeqLabelEvaluator
|
||||
from fastNLP.core.preprocess import save_pickle, load_pickle
|
||||
from fastNLP.core.utils import save_pickle, load_pickle
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")
|
||||
|
@ -1,17 +1,16 @@
|
||||
import os
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.core.metrics import SeqLabelEvaluator
|
||||
from fastNLP.core.predictor import SeqLabelInfer
|
||||
from fastNLP.core.preprocess import save_pickle, load_pickle
|
||||
from fastNLP.core.tester import SeqLabelTester
|
||||
from fastNLP.core.trainer import SeqLabelTrainer
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader, RawDataSetLoader
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
from fastNLP.core.utils import save_pickle, load_pickle
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.io.dataset_loader import TokenizeDataSetLoader, RawDataSetLoader
|
||||
from fastNLP.io.model_loader import ModelLoader
|
||||
from fastNLP.io.model_saver import ModelSaver
|
||||
from fastNLP.models.sequence_modeling import SeqLabeling
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
|
||||
data_name = "pku_training.utf8"
|
||||
cws_data_path = "./test/data_for_tests/cws_pku_utf_8"
|
||||
|
@ -2,15 +2,15 @@ import os
|
||||
|
||||
from fastNLP.core.metrics import SeqLabelEvaluator
|
||||
from fastNLP.core.optimizer import Optimizer
|
||||
from fastNLP.core.preprocess import save_pickle
|
||||
from fastNLP.core.tester import SeqLabelTester
|
||||
from fastNLP.core.trainer import SeqLabelTrainer
|
||||
from fastNLP.core.utils import save_pickle
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.io.dataset_loader import TokenizeDataSetLoader
|
||||
from fastNLP.io.model_loader import ModelLoader
|
||||
from fastNLP.io.model_saver import ModelSaver
|
||||
from fastNLP.models.sequence_modeling import SeqLabeling
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
|
||||
pickle_path = "./seq_label/"
|
||||
model_name = "seq_label_model.pkl"
|
||||
|
@ -8,15 +8,15 @@ import sys
|
||||
sys.path.append("..")
|
||||
from fastNLP.core.predictor import ClassificationInfer
|
||||
from fastNLP.core.trainer import ClassificationTrainer
|
||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.loader.dataset_loader import ClassDataSetLoader
|
||||
from fastNLP.loader.model_loader import ModelLoader
|
||||
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
|
||||
from fastNLP.io.dataset_loader import ClassDataSetLoader
|
||||
from fastNLP.io.model_loader import ModelLoader
|
||||
from fastNLP.models.cnn_text_classification import CNNText
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
from fastNLP.io.model_saver import ModelSaver
|
||||
from fastNLP.core.optimizer import Optimizer
|
||||
from fastNLP.core.loss import Loss
|
||||
from fastNLP.core.dataset import TextClassifyDataSet
|
||||
from fastNLP.core.preprocess import save_pickle, load_pickle
|
||||
from fastNLP.core.utils import save_pickle, load_pickle
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files")
|
||||
|
@ -1,213 +0,0 @@
|
||||
# encoding: utf-8
|
||||
import os
|
||||
|
||||
from fastNLP.core.preprocess import save_pickle
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from fastNLP.fastnlp import FastNLP
|
||||
from fastNLP.fastnlp import interpret_word_seg_results, interpret_cws_pos_results
|
||||
from fastNLP.models.cnn_text_classification import CNNText
|
||||
from fastNLP.models.sequence_modeling import AdvSeqLabel
|
||||
from fastNLP.saver.model_saver import ModelSaver
|
||||
|
||||
PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/"
|
||||
PATH_TO_POS_TAG_PICKLE_FILES = "/home/zyfeng/data/crf_seg/"
|
||||
PATH_TO_TEXT_CLASSIFICATION_PICKLE_FILES = "/home/zyfeng/data/text_classify/"
|
||||
|
||||
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
|
||||
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1
|
||||
DEFAULT_RESERVED_LABEL = ['<reserved-2>',
|
||||
'<reserved-3>',
|
||||
'<reserved-4>'] # dict index = 2~4
|
||||
|
||||
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
|
||||
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
|
||||
DEFAULT_RESERVED_LABEL[2]: 4}
|
||||
|
||||
|
||||
def word_seg(model_dir, config, section):
|
||||
nlp = FastNLP(model_dir=model_dir)
|
||||
nlp.load("cws_basic_model", config_file=config, section_name=section)
|
||||
text = ["这是最好的基于深度学习的中文分词系统。",
|
||||
"大王叫我来巡山。",
|
||||
"我党多年来致力于改善人民生活水平。"]
|
||||
results = nlp.run(text)
|
||||
print(results)
|
||||
for example in results:
|
||||
words, labels = [], []
|
||||
for res in example:
|
||||
words.append(res[0])
|
||||
labels.append(res[1])
|
||||
print(interpret_word_seg_results(words, labels))
|
||||
|
||||
|
||||
def mock_cws():
|
||||
os.makedirs("mock", exist_ok=True)
|
||||
text = ["这是最好的基于深度学习的中文分词系统。",
|
||||
"大王叫我来巡山。",
|
||||
"我党多年来致力于改善人民生活水平。"]
|
||||
|
||||
word2id = Vocabulary()
|
||||
word_list = [ch for ch in "".join(text)]
|
||||
word2id.update(word_list)
|
||||
save_pickle(word2id, "./mock/", "word2id.pkl")
|
||||
|
||||
class2id = Vocabulary(need_default=False)
|
||||
label_list = ['B', 'M', 'E', 'S']
|
||||
class2id.update(label_list)
|
||||
save_pickle(class2id, "./mock/", "label2id.pkl")
|
||||
|
||||
model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)}
|
||||
config_file = """
|
||||
[test_section]
|
||||
vocab_size = {}
|
||||
word_emb_dim = 50
|
||||
rnn_hidden_units = 50
|
||||
num_classes = {}
|
||||
""".format(len(word2id), len(class2id))
|
||||
with open("mock/test.cfg", "w", encoding="utf-8") as f:
|
||||
f.write(config_file)
|
||||
|
||||
model = AdvSeqLabel(model_args)
|
||||
ModelSaver("mock/cws_basic_model_v_0.pkl").save_pytorch(model)
|
||||
|
||||
|
||||
def test_word_seg():
|
||||
# fake the model and pickles
|
||||
print("start mocking")
|
||||
mock_cws()
|
||||
# run the inference codes
|
||||
print("start testing")
|
||||
word_seg("./mock/", "test.cfg", "test_section")
|
||||
# clean up environments
|
||||
print("clean up")
|
||||
os.system("rm -rf mock")
|
||||
|
||||
|
||||
def pos_tag(model_dir, config, section):
|
||||
nlp = FastNLP(model_dir=model_dir)
|
||||
nlp.load("pos_tag_model", config_file=config, section_name=section)
|
||||
text = ["这是最好的基于深度学习的中文分词系统。",
|
||||
"大王叫我来巡山。",
|
||||
"我党多年来致力于改善人民生活水平。"]
|
||||
results = nlp.run(text)
|
||||
for example in results:
|
||||
words, labels = [], []
|
||||
for res in example:
|
||||
words.append(res[0])
|
||||
labels.append(res[1])
|
||||
try:
|
||||
print(interpret_cws_pos_results(words, labels))
|
||||
except RuntimeError:
|
||||
print("inconsistent pos tags. this is for test only.")
|
||||
|
||||
|
||||
def mock_pos_tag():
|
||||
os.makedirs("mock", exist_ok=True)
|
||||
text = ["这是最好的基于深度学习的中文分词系统。",
|
||||
"大王叫我来巡山。",
|
||||
"我党多年来致力于改善人民生活水平。"]
|
||||
|
||||
vocab = Vocabulary()
|
||||
word_list = [ch for ch in "".join(text)]
|
||||
vocab.update(word_list)
|
||||
save_pickle(vocab, "./mock/", "word2id.pkl")
|
||||
|
||||
idx2label = Vocabulary(need_default=False)
|
||||
label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv']
|
||||
idx2label.update(label_list)
|
||||
save_pickle(idx2label, "./mock/", "label2id.pkl")
|
||||
|
||||
model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
|
||||
config_file = """
|
||||
[test_section]
|
||||
vocab_size = {}
|
||||
word_emb_dim = 50
|
||||
rnn_hidden_units = 50
|
||||
num_classes = {}
|
||||
""".format(len(vocab), len(idx2label))
|
||||
with open("mock/test.cfg", "w", encoding="utf-8") as f:
|
||||
f.write(config_file)
|
||||
|
||||
model = AdvSeqLabel(model_args)
|
||||
ModelSaver("mock/pos_tag_model_v_0.pkl").save_pytorch(model)
|
||||
|
||||
|
||||
def test_pos_tag():
|
||||
mock_pos_tag()
|
||||
pos_tag("./mock/", "test.cfg", "test_section")
|
||||
os.system("rm -rf mock")
|
||||
|
||||
|
||||
def text_classify(model_dir, config, section):
|
||||
nlp = FastNLP(model_dir=model_dir)
|
||||
nlp.load("text_classify_model", config_file=config, section_name=section)
|
||||
text = [
|
||||
"世界物联网大会明日在京召开龙头股启动在即",
|
||||
"乌鲁木齐市新增一处城市中心旅游目的地",
|
||||
"朱元璋的大明朝真的源于明教吗?——告诉你一个真实的“明教”"]
|
||||
results = nlp.run(text)
|
||||
print(results)
|
||||
|
||||
|
||||
def mock_text_classify():
|
||||
os.makedirs("mock", exist_ok=True)
|
||||
text = ["世界物联网大会明日在京召开龙头股启动在即",
|
||||
"乌鲁木齐市新增一处城市中心旅游目的地",
|
||||
"朱元璋的大明朝真的源于明教吗?——告诉你一个真实的“明教”"
|
||||
]
|
||||
vocab = Vocabulary()
|
||||
word_list = [ch for ch in "".join(text)]
|
||||
vocab.update(word_list)
|
||||
save_pickle(vocab, "./mock/", "word2id.pkl")
|
||||
|
||||
idx2label = Vocabulary(need_default=False)
|
||||
label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F']
|
||||
idx2label.update(label_list)
|
||||
save_pickle(idx2label, "./mock/", "label2id.pkl")
|
||||
|
||||
model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
|
||||
config_file = """
|
||||
[test_section]
|
||||
vocab_size = {}
|
||||
word_emb_dim = 50
|
||||
rnn_hidden_units = 50
|
||||
num_classes = {}
|
||||
""".format(len(vocab), len(idx2label))
|
||||
with open("mock/test.cfg", "w", encoding="utf-8") as f:
|
||||
f.write(config_file)
|
||||
|
||||
model = CNNText(model_args)
|
||||
ModelSaver("mock/text_class_model_v0.pkl").save_pytorch(model)
|
||||
|
||||
|
||||
def test_text_classify():
|
||||
mock_text_classify()
|
||||
text_classify("./mock/", "test.cfg", "test_section")
|
||||
os.system("rm -rf mock")
|
||||
|
||||
|
||||
def test_word_seg_interpret():
|
||||
foo = [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'),
|
||||
('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'),
|
||||
('。', 'S')]]
|
||||
chars = [x[0] for x in foo[0]]
|
||||
labels = [x[1] for x in foo[0]]
|
||||
print(interpret_word_seg_results(chars, labels))
|
||||
|
||||
|
||||
def test_interpret_cws_pos_results():
|
||||
foo = [
|
||||
[('这', 'S-r'), ('是', 'S-v'), ('最', 'S-d'), ('好', 'S-a'), ('的', 'S-u'), ('基', 'B-p'), ('于', 'E-p'), ('深', 'B-d'),
|
||||
('度', 'E-d'), ('学', 'B-v'), ('习', 'E-v'), ('的', 'S-u'), ('中', 'B-nz'), ('文', 'E-nz'), ('分', 'B-vn'),
|
||||
('词', 'E-vn'), ('系', 'B-n'), ('统', 'E-n'), ('。', 'S-w')]
|
||||
]
|
||||
chars = [x[0] for x in foo[0]]
|
||||
labels = [x[1] for x in foo[0]]
|
||||
print(interpret_cws_pos_results(chars, labels))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_word_seg()
|
||||
test_pos_tag()
|
||||
test_text_classify()
|
||||
test_word_seg_interpret()
|
||||
test_interpret_cws_pos_results()
|
Loading…
Reference in New Issue
Block a user