Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-11-30 03:07:59 +08:00)
changes to preprocessor, trainer, inference & seq modeling
- [trainer] rename "batchify" to "make_batch" in trainer
- [trainer] pack (batch_x_pad, seq_len) into batch_x in make_batch for sequence labeling, because the sequence length before padding is needed to build masks (a standalone sketch of this packing follows the commit header below)
- [trainer] unpack it in data_forward
- [model] shorten the model definition
- [inference] build the Inference class; test_POS_pipeline.py can now run inference
- [preprocessor] handle pickles in a nicer manner
- [FastNLP] add fastNLP.py as a high-level API, not finished yet
This commit is contained in:
parent 22d900b7a3
commit fe17f611b6
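A minimal, self-contained sketch of the (batch_x_pad, seq_len) packing described in the commit message above, i.e. the new make_batch / data_forward contract. The pad() helper and the sample data here are illustrative stand-ins, not the fastNLP code itself:

# Sketch: make_batch pads a batch and packs it together with the original
# sequence lengths; data_forward unpacks the tuple so masks can be built later.
def pad(batch, fill=0):
    max_length = max(len(x) for x in batch)
    return [sample + [fill] * (max_length - len(sample)) for sample in batch]

def make_batch(samples, output_length=True):
    batch_x = [sample[0] for sample in samples]
    batch_y = [sample[1] for sample in samples]
    batch_x_pad = pad(batch_x)
    if output_length:
        seq_len = [len(x) for x in batch_x]      # lengths before padding
        return (batch_x_pad, seq_len), batch_y   # packed, as in this commit
    return batch_x_pad, batch_y

def data_forward(inputs):
    if isinstance(inputs, tuple):                # unpack (batch_x_pad, seq_len)
        x, seq_len = inputs
    else:
        x, seq_len = inputs, None
    return x, seq_len

samples = [[[1, 2, 3, 4], [0]], [[1, 3], [1]]]
packed, batch_y = make_batch(samples)
x, seq_len = data_forward(packed)
print(x, seq_len)  # [[1, 2, 3, 4], [1, 3, 0, 0]] [4, 2]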
fastNLP/action/inference.py
@@ -1,26 +1,116 @@
import torch

from fastNLP.action.action import Batchifier, SequentialSampler
from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL


class Inference(object):
    """
    This is an interface focusing on predicting output based on trained models.
    It does not care about evaluations of the model.
    It does not care about evaluations of the model, which is different from Tester.
    This is a high-level model wrapper to be called by FastNLP.

    """

    def __init__(self):
        pass
    def __init__(self, pickle_path):
        self.batch_size = 1
        self.batch_output = []
        self.iterator = None
        self.pickle_path = pickle_path
        self.index2label = load_pickle(self.pickle_path, "id2class.pkl")
        self.word2index = load_pickle(self.pickle_path, "word2id.pkl")

    def predict(self, model, data):
    def predict(self, network, data):
        """
        this is actually a forward pass. shall be shared by Trainer/Tester
        :param model:
        :param data:
        :return result: the output results
        Perform inference.
        :param network:
        :param data: multi-level lists of strings
        :return result: the model outputs
        """
        raise NotImplementedError
        # transform strings into indices
        data = self.prepare_input(data)

    def prepare_input(self, data_path):
        # turn on the testing mode; clean up the history
        self.mode(network, test=True)

        self.iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))

        num_iter = len(data) // self.batch_size

        for step in range(num_iter):
            batch_x = self.batchify(data)

            prediction = self.data_forward(network, batch_x)

            self.batch_output.append(prediction)

        return self.prepare_output(self.batch_output)

    def mode(self, network, test=True):
        if test:
            network.eval()
        else:
            network.train()
        self.batch_output.clear()

    def data_forward(self, network, x):
        """
        This can also be shared.
        :param data_path:
        This is only for sequence labeling with CRF decoder. To do: more general?
        :param network:
        :param x:
        :return:
        """
        raise NotImplementedError
        seq_len = [len(seq) for seq in x]
        x = torch.Tensor(x).long()
        y = network(x)
        prediction = network.prediction(y, seq_len)
        # To do: hide framework
        results = torch.Tensor(prediction).view(-1, )
        return list(results.data)

    def batchify(self, data):
        indices = next(self.iterator)
        batch_x = [data[idx] for idx in indices]
        batch_x = self.pad(batch_x)
        return batch_x

    @staticmethod
    def pad(batch, fill=0):
        """
        Pad a batch of samples to maximum length.
        :param batch: list of list
        :param fill: word index to pad, default 0.
        :return: a padded batch
        """
        max_length = max([len(x) for x in batch])
        for idx, sample in enumerate(batch):
            if len(sample) < max_length:
                batch[idx] = sample + [fill] * (max_length - len(sample))
        return batch

    def prepare_input(self, data):
        """
        Transform a three-level list of strings into one of indices.
        :param data:
            [
                [word_11, word_12, ...],
                [word_21, word_22, ...],
                ...
            ]
        """
        data_index = []
        default_unknown_index = self.word2index[DEFAULT_UNKNOWN_LABEL]
        for example in data:
            data_index.append([self.word2index.get(w, default_unknown_index) for w in example])
        return data_index

    def prepare_output(self, batch_outputs):
        """
        Transform a list of batch outputs into strings.
        :param batch_outputs: list of list [num_batch, tag_seq_length]
        :return:
        """
        results = []
        for batch in batch_outputs:
            results.append([self.index2label[int(x.data)] for x in batch])
        return results
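A standalone sketch of what Inference.prepare_input and prepare_output above do with the word2id.pkl / id2class.pkl mappings; the tiny vocabularies here are invented for illustration:

word2index = {"<pad>": 0, "<unk>": 1, "new": 5, "year": 6}   # stands in for word2id.pkl
index2label = {0: "B", 1: "I", 2: "O"}                       # stands in for id2class.pkl

def prepare_input(data):
    unk = word2index["<unk>"]
    return [[word2index.get(w, unk) for w in sentence] for sentence in data]

def prepare_output(batch_outputs):
    return [[index2label[int(i)] for i in batch] for batch in batch_outputs]

print(prepare_input([["new", "year", "speech"]]))  # [[5, 6, 1]]  ("speech" falls back to <unk>)
print(prepare_output([[0, 1, 2]]))                 # [['B', 'I', 'O']]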
fastNLP/action/trainer.py
@@ -86,7 +86,7 @@ class BaseTrainer(Action):

            # training iterations in one epoch
            for step in range(iterations):
                batch_x, batch_y = self.batchify(data_train)  # pad ?
                batch_x, batch_y = self.make_batch(data_train)

                prediction = self.data_forward(network, batch_x)

@@ -180,7 +180,7 @@ class BaseTrainer(Action):
        """
        raise NotImplementedError

    def batchify(self, data, output_length=True):
    def make_batch(self, data, output_length=True):
        """
        1. Perform batching from data and produce a batch of training data.
        2. Add padding.
@@ -191,9 +191,12 @@ class BaseTrainer(Action):
                [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
                ...
            ]
        :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
        :return (batch_x, seq_len): tuple of two elements, if output_length is True.
            batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
            seq_len: list. The length of the pre-padded sequence, if output_length is True.
            batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
            seq_len: list. The length of the pre-padded sequence, if output_length is True.

            return batch_x and batch_y, if output_length is False
        """
        indices = next(self.iterator)
        batch = [data[idx] for idx in indices]
@@ -202,7 +205,7 @@ class BaseTrainer(Action):
        batch_x_pad = self.pad(batch_x)
        if output_length:
            seq_len = [len(x) for x in batch_x]
            return batch_x_pad, batch_y, seq_len
            return (batch_x_pad, seq_len), batch_y
        else:
            return batch_x_pad, batch_y

@@ -292,17 +295,23 @@ class POSTrainer(BaseTrainer):
        data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
        return data_train, data_dev, 0, 1

    def data_forward(self, network, x):
    def data_forward(self, network, inputs):
        """
        :param network: the PyTorch model
        :param x: list of list, [batch_size, max_len]
        :param inputs: list of list, [batch_size, max_len],
            or tuple of (batch_x, seq_len), batch_x == [batch_size, max_len]
        :return y: [batch_size, max_len, tag_size]
        """
        self.seq_len = [len(seq) for seq in x]
        # unpack the returned value from make_batch
        if isinstance(inputs, tuple):
            x = inputs[0]
            self.seq_len = inputs[1]
        else:
            x = inputs
        x = torch.Tensor(x).long()
        self.batch_size = x.size(0)
        self.max_len = x.size(1)
        # self.mask = seq_mask(seq_len, self.max_len)

        y = network(x)
        return y

@@ -325,11 +334,12 @@ class POSTrainer(BaseTrainer):
    def get_loss(self, predict, truth):
        """
        Compute loss given prediction and ground truth.
        :param predict: prediction label vector, [batch_size, tag_size, tag_size]
        :param predict: prediction label vector, [batch_size, max_len, tag_size]
        :param truth: ground truth label vector, [batch_size, max_len]
        :return: a scalar
        """
        truth = torch.Tensor(truth)
        assert truth.shape == (self.batch_size, self.max_len)
        if self.loss_func is None:
            if hasattr(self.model, "loss"):
                self.loss_func = self.model.loss
@@ -347,6 +357,35 @@ class POSTrainer(BaseTrainer):
        else:
            return False

    def make_batch(self, data, output_length=True):
        """
        1. Perform batching from data and produce a batch of training data.
        2. Add padding.
        :param data: list. Each entry is a sample, which is also a list of features and label(s).
            E.g.
            [
                [[word_11, word_12, word_13], [label_11, label_12]],  # sample 1
                [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
                ...
            ]
        :return (batch_x, seq_len): tuple of two elements, if output_length is True.
            batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
            seq_len: list. The length of the pre-padded sequence, if output_length is True.
            batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]

            return batch_x and batch_y, if output_length is False
        """
        indices = next(self.iterator)
        batch = [data[idx] for idx in indices]
        batch_x = [sample[0] for sample in batch]
        batch_y = [sample[1] for sample in batch]
        batch_x_pad = self.pad(batch_x)
        if output_length:
            seq_len = [len(x) for x in batch_x]
            return (batch_x_pad, seq_len), batch_y
        else:
            return batch_x_pad, batch_y


class LanguageModelTrainer(BaseTrainer):
    """
@@ -438,7 +477,7 @@ class ClassTrainer(BaseTrainer):

        # training iterations in one epoch
        step = 0
        for batch_x, batch_y in self.batchify(data_train):
        for batch_x, batch_y in self.make_batch(data_train):
            prediction = self.data_forward(network, batch_x)

            loss = self.get_loss(prediction, batch_y)
@@ -533,7 +572,7 @@ class ClassTrainer(BaseTrainer):
        """Apply gradient."""
        self.optimizer.step()

    def batchify(self, data):
    def make_batch(self, data):
        """Batch and pad data."""
        for indices in self.iterator:
            batch = [data[idx] for idx in indices]
@@ -559,4 +598,4 @@ if __name__ == "__name__":
    train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"}
    trainer = BaseTrainer(train_args)
    data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10]
    trainer.batchify(data=data_train)
    trainer.make_batch(data=data_train)
fastNLP/fastNLP.py  (new file, 104 lines)
@@ -0,0 +1,104 @@
from fastNLP.action.inference import Inference
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader

"""
mapping from model name to [URL, file_name.class_name]
Notice that the class of the model should be in the "models" directory.

Example:
    "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling"]
"""
FastNLP_MODEL_COLLECTION = {
    "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling"]
}


class FastNLP(object):
    """
    High-level interface for direct model inference.
    Usage:
        fastnlp = FastNLP()
        fastnlp.load("zh_pos_tag_model")
        text = "这是最好的基于深度学习的中文分词系统。"
        result = fastnlp.run(text)
        print(result)  # ["这", "是", "最好", "的", "基于", "深度学习", "的", "中文", "分词", "系统", "。"]
    """

    def __init__(self, model_dir="./"):
        self.model_dir = model_dir
        self.model = None

    def load(self, model_name):
        """
        Load a pre-trained FastNLP model together with additional data.
        :param model_name: str, the name of a FastNLP model.
        """
        assert type(model_name) is str
        if model_name not in FastNLP_MODEL_COLLECTION:
            raise ValueError("No FastNLP model named {}.".format(model_name))

        if not self.model_exist(model_dir=self.model_dir):
            self._download(model_name, FastNLP_MODEL_COLLECTION[model_name][0])

        model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name][1])

        model_args = ConfigSection()
        # To do: customized config file for model init parameters
        ConfigLoader.load_config(self.model_dir + "default.cfg", model_args)

        model = model_class(model_args)

        # To do: framework independent
        ModelLoader.load_pytorch(model, self.model_dir + model_name)

        self.model = model

        print("Model loaded.")

    def run(self, infer_input):
        """
        Perform inference over given input using the loaded model.
        :param infer_input: str, raw text
        :return results:
        """
        infer = Inference()
        data = infer.prepare_input(infer_input)
        results = infer.predict(self.model, data)
        return results

    @staticmethod
    def _get_model_class(file_class_name):
        """
        Fetch the class specified by <file_class_name>.
        :param file_class_name: str, contains the name of the Python module followed by the name of the class.
            Example: "sequence_modeling.SeqLabeling"
        :return module: the model class
        """
        import_prefix = "fastNLP.models."
        parts = (import_prefix + file_class_name).split(".")
        from_module = ".".join(parts[:-1])
        module = __import__(from_module)
        for sub in parts[1:]:
            module = getattr(module, sub)
        return module

    def _load(self, model_dir, model_name):
        # To do
        return 0

    def _download(self, model_name, url):
        """
        Download the model weights from <url> and save in <self.model_dir>.
        :param model_name:
        :param url:
        """
        print("Downloading {} from {}".format(model_name, url))
        # To do

    def model_exist(self, model_dir):
        """
        Check whether the desired model is already in the directory.
        :param model_dir:
        """
        pass
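A standalone illustration of the dynamic lookup done by FastNLP._get_model_class above: __import__("a.b.c") returns the top-level package, so the remaining dotted parts are resolved with getattr(). A stdlib path is used here instead of fastNLP.models, purely for illustration:

def get_attr_by_path(dotted_name):
    parts = dotted_name.split(".")
    module = __import__(".".join(parts[:-1]))  # returns the top-level package ("os")
    for sub in parts[1:]:
        module = getattr(module, sub)
    return module

print(get_attr_by_path("os.path.join"))  # <function join at 0x...>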
@@ -17,7 +17,7 @@ class BaseLoader(object):
    def load_lines(self):
        with open(self.data_path, "r", encoding="utf-8") as f:
            text = f.readlines()
            return text
            return [line.strip() for line in text]


class ToyLoader0(BaseLoader):
fastNLP/loader/model_loader.py
@@ -11,9 +11,11 @@ class ModelLoader(BaseLoader):
    def __init__(self, data_name, data_path):
        super(ModelLoader, self).__init__(data_name, data_path)

    def load_pytorch(self, empty_model):
    @staticmethod
    def load_pytorch(empty_model, model_path):
        """
        Load model parameters from .pkl files into the empty PyTorch model.
        :param empty_model: a PyTorch model with initialized parameters.
        :param model_path: str, the path to the saved model.
        """
        empty_model.load_state_dict(torch.load(self.data_path))
        empty_model.load_state_dict(torch.load(model_path))
fastNLP/loader/preprocess.py
@@ -1,346 +1,361 @@
import _pickle
import os

DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
DEFAULT_RESERVED_LABEL = ['<reserved-2>',
                          '<reserved-3>',
                          '<reserved-4>']  # dict index = 2~4

DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
                         DEFAULT_RESERVED_LABEL[2]: 4}


# the first vocab in dict with the index = 5


class BasePreprocess(object):

    def __init__(self, data, pickle_path):
        super(BasePreprocess, self).__init__()
        self.data = data
        self.pickle_path = pickle_path
        if not self.pickle_path.endswith('/'):
            self.pickle_path = self.pickle_path + '/'


class POSPreprocess(BasePreprocess):
    """
    This class is used to preprocess the POS datasets.

    """

    def __init__(self, data, pickle_path="./", train_dev_split=0):
        """
        Preprocess pipeline, including building mapping from words to index, from index to words,
        from labels/classes to index, from index to labels/classes.
        :param data: three-level list
            [
                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
                ...
            ]
        :param pickle_path: str, the directory to the pickle files. Default: "./"
        :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0.

        To do:
            1. simplify __init__
        """
        super(POSPreprocess, self).__init__(data, pickle_path)

        self.pickle_path = pickle_path

        if self.pickle_exist("word2id.pkl"):
            # load word2index because the construction of the following objects needs it
            with open(os.path.join(self.pickle_path, "word2id.pkl"), "rb") as f:
                self.word2index = _pickle.load(f)
        else:
            self.word2index, self.label2index = self.build_dict(data)
            with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f:
                _pickle.dump(self.word2index, f)

        if self.pickle_exist("class2id.pkl"):
            with open(os.path.join(self.pickle_path, "class2id.pkl"), "rb") as f:
                self.label2index = _pickle.load(f)
        else:
            with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f:
                _pickle.dump(self.label2index, f)
            # something will be wrong if word2id.pkl is found but class2id.pkl is not found

        if not self.pickle_exist("id2word.pkl"):
            index2word = self.build_reverse_dict(self.word2index)
            with open(os.path.join(self.pickle_path, "id2word.pkl"), "wb") as f:
                _pickle.dump(index2word, f)

        if not self.pickle_exist("id2class.pkl"):
            index2label = self.build_reverse_dict(self.label2index)
            with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f:
                _pickle.dump(index2label, f)

        if not self.pickle_exist("data_train.pkl"):
            data_train = self.to_index(data)
            if train_dev_split > 0 and not self.pickle_exist("data_dev.pkl"):
                data_dev = data_train[: int(len(data_train) * train_dev_split)]
                with open(os.path.join(self.pickle_path, "data_dev.pkl"), "wb") as f:
                    _pickle.dump(data_dev, f)
            with open(os.path.join(self.pickle_path, "data_train.pkl"), "wb") as f:
                _pickle.dump(data_train, f)
import _pickle
import os

DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
DEFAULT_RESERVED_LABEL = ['<reserved-2>',
                          '<reserved-3>',
                          '<reserved-4>']  # dict index = 2~4

DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
                         DEFAULT_RESERVED_LABEL[2]: 4}


# the first vocab in dict with the index = 5

def save_pickle(obj, pickle_path, file_name):
    with open(os.path.join(pickle_path, file_name), "wb") as f:
        _pickle.dump(obj, f)
    print("{} saved.".format(file_name))


def load_pickle(pickle_path, file_name):
    with open(os.path.join(pickle_path, file_name), "rb") as f:
        obj = _pickle.load(f)
    return obj


def pickle_exist(pickle_path, pickle_name):
    """
    :param pickle_path: the directory of target pickle file
    :param pickle_name: the filename of target pickle file
    :return: True if file exists else False
    """
    if not os.path.exists(pickle_path):
        os.makedirs(pickle_path)
    file_name = os.path.join(pickle_path, pickle_name)
    if os.path.exists(file_name):
        return True
    else:
        return False


class BasePreprocess(object):

    def __init__(self, data, pickle_path):
        super(BasePreprocess, self).__init__()
        # self.data = data
        self.pickle_path = pickle_path
        if not self.pickle_path.endswith('/'):
            self.pickle_path = self.pickle_path + '/'


class POSPreprocess(BasePreprocess):
    """
    This class is used to preprocess the POS Tag datasets.

    """

    def __init__(self, data, pickle_path="./", train_dev_split=0):
        """
        Preprocess pipeline, including building mapping from words to index, from index to words,
        from labels/classes to index, from index to labels/classes.
        :param data: three-level list
            [
                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
                ...
            ]
        :param pickle_path: str, the directory to the pickle files. Default: "./"
        :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0.

        """
        super(POSPreprocess, self).__init__(data, pickle_path)

        self.pickle_path = pickle_path

        if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
            self.word2index = load_pickle(self.pickle_path, "word2id.pkl")
            self.label2index = load_pickle(self.pickle_path, "class2id.pkl")
        else:
            self.word2index, self.label2index = self.build_dict(data)
            save_pickle(self.word2index, self.pickle_path, "word2id.pkl")
            save_pickle(self.label2index, self.pickle_path, "class2id.pkl")

        if not pickle_exist(pickle_path, "id2word.pkl"):
            index2word = self.build_reverse_dict(self.word2index)
            save_pickle(index2word, self.pickle_path, "id2word.pkl")

        if not pickle_exist(pickle_path, "id2class.pkl"):
            index2label = self.build_reverse_dict(self.label2index)
            save_pickle(index2label, self.pickle_path, "id2class.pkl")

        if not pickle_exist(pickle_path, "data_train.pkl"):
            data_train = self.to_index(data)
            if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"):
                data_dev = data_train[: int(len(data_train) * train_dev_split)]
                save_pickle(data_dev, self.pickle_path, "data_dev.pkl")
            save_pickle(data_train, self.pickle_path, "data_train.pkl")

    def build_dict(self, data):
        """
        Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
        :param data: three-level list
            [
                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
                ...
            ]
        :return word2index: dict of {str, int}
                label2index: dict of {str, int}
        """
        label2index = {}
        word2index = DEFAULT_WORD_TO_INDEX
        for example in data:
            for word, label in zip(example[0], example[1]):
                if word not in word2index:
                    word2index[word] = len(word2index)
                if label not in label2index:
                    label2index[label] = len(label2index)
        return word2index, label2index

    def build_reverse_dict(self, word_dict):
        id2word = {word_dict[w]: w for w in word_dict}
        return id2word

    def to_index(self, data):
        """
        Convert word strings and label strings into indices.
        :param data: three-level list
            [
                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
                ...
            ]
        :return data_index: the shape of data, but each string is replaced by its corresponding index
        """
        data_index = []
        for example in data:
            word_list = []
            label_list = []
            for word, label in zip(example[0], example[1]):
                word_list.append(self.word2index[word])
                label_list.append(self.label2index[label])
            data_index.append([word_list, label_list])
        return data_index

    @property
    def vocab_size(self):
        return len(self.word2index)

    @property
    def num_classes(self):
        return len(self.label2index)


class ClassPreprocess(BasePreprocess):
    """
    Pre-process the classification datasets.

    Params:
        pickle_path - directory to save result of pre-processing
    Saves:
        word2id.pkl
        id2word.pkl
        class2id.pkl
        id2class.pkl
        embedding.pkl
        data_train.pkl
        data_dev.pkl
        data_test.pkl
    """

    def __init__(self, pickle_path):
        # super(ClassPreprocess, self).__init__(data, pickle_path)
        self.word_dict = None
        self.label_dict = None
        self.pickle_path = pickle_path  # save directory

    def process(self, data, save_name):
        """
        Process data.

        Params:
            data - nested list, data = [sample1, sample2, ...],
                sample = [sentence, label], sentence = [word1, word2, ...]
            save_name - name of processed data, such as data_train.pkl
        Returns:
            vocab_size - vocabulary size
            n_classes - number of classes
        """
        self.build_dict(data)
        self.word2id()
        vocab_size = self.id2word()
        self.class2id()
        num_classes = self.id2class()
        self.embedding()
        self.data_generate(data, save_name)

        return vocab_size, num_classes

    def build_dict(self, data):
        """Build vocabulary."""

        # just read if word2id.pkl and class2id.pkl exist
        if self.pickle_exist("word2id.pkl") and \
                self.pickle_exist("class2id.pkl"):
            file_name = os.path.join(self.pickle_path, "word2id.pkl")
            with open(file_name, 'rb') as f:
                self.word_dict = _pickle.load(f)
            file_name = os.path.join(self.pickle_path, "class2id.pkl")
            with open(file_name, 'rb') as f:
                self.label_dict = _pickle.load(f)
            return

        # build vocabulary from scratch if nothing exists
        self.word_dict = {
            DEFAULT_PADDING_LABEL: 0,
            DEFAULT_UNKNOWN_LABEL: 1,
            DEFAULT_RESERVED_LABEL[0]: 2,
            DEFAULT_RESERVED_LABEL[1]: 3,
            DEFAULT_RESERVED_LABEL[2]: 4}
        self.label_dict = {}

        # collect every word and label
        for sent, label in data:
            if len(sent) <= 1:
                continue

            if label not in self.label_dict:
                index = len(self.label_dict)
                self.label_dict[label] = index

            for word in sent:
                if word not in self.word_dict:
                    index = len(self.word_dict)
                    self.word_dict[word[0]] = index

    def pickle_exist(self, pickle_name):
        """
        Check whether a pickle file exists.

        Params
            pickle_name: the filename of target pickle file
        Return
            True if file exists else False
        """
        if not os.path.exists(self.pickle_path):
            os.makedirs(self.pickle_path)
        file_name = os.path.join(self.pickle_path, pickle_name)
        if os.path.exists(file_name):
            return True
        else:
            return False

    def word2id(self):
        """Save vocabulary of {word:id} mapping format."""
        # nothing will be done if word2id.pkl exists
        if self.pickle_exist("word2id.pkl"):
            return

        file_name = os.path.join(self.pickle_path, "word2id.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(self.word_dict, f)

    def id2word(self):
        """Save vocabulary of {id:word} mapping format."""
        # nothing will be done if id2word.pkl exists
        if self.pickle_exist("id2word.pkl"):
            file_name = os.path.join(self.pickle_path, "id2word.pkl")
            with open(file_name, 'rb') as f:
                id2word_dict = _pickle.load(f)
            return len(id2word_dict)

        id2word_dict = {self.word_dict[w]: w for w in self.word_dict}
        file_name = os.path.join(self.pickle_path, "id2word.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(id2word_dict, f)
        return len(id2word_dict)

    def class2id(self):
        """Save mapping of {class:id}."""
        # nothing will be done if class2id.pkl exists
        if self.pickle_exist("class2id.pkl"):
            return

        file_name = os.path.join(self.pickle_path, "class2id.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(self.label_dict, f)

    def id2class(self):
        """Save mapping of {id:class}."""
        # nothing will be done if id2class.pkl exists
        if self.pickle_exist("id2class.pkl"):
            file_name = os.path.join(self.pickle_path, "id2class.pkl")
            with open(file_name, "rb") as f:
                id2class_dict = _pickle.load(f)
            return len(id2class_dict)

        id2class_dict = {self.label_dict[c]: c for c in self.label_dict}
        file_name = os.path.join(self.pickle_path, "id2class.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(id2class_dict, f)
        return len(id2class_dict)

    def embedding(self):
        """Save embedding lookup table corresponding to vocabulary."""
        # nothing will be done if embedding.pkl exists
        if self.pickle_exist("embedding.pkl"):
            return

        # retrieve vocabulary from pre-trained embedding (not implemented)

    def data_generate(self, data_src, save_name):
        """Convert dataset from text to digit."""

        # nothing will be done if file exists
        save_path = os.path.join(self.pickle_path, save_name)
        if os.path.exists(save_path):
            return

        data = []
        # for every sample
        for sent, label in data_src:
            if len(sent) <= 1:
                continue

            label_id = self.label_dict[label]  # label id
            sent_id = []  # sentence ids
            for word in sent:
                if word in self.word_dict:
                    sent_id.append(self.word_dict[word])
                else:
                    sent_id.append(self.word_dict[DEFAULT_UNKNOWN_LABEL])
            data.append([sent_id, label_id])

        # save data
        with open(save_path, "wb") as f:
            _pickle.dump(data, f)


class LMPreprocess(BasePreprocess):
    def __init__(self, data, pickle_path):
        super(LMPreprocess, self).__init__(data, pickle_path)


def infer_preprocess(pickle_path, data):
    """
    Preprocess over inference data.
    Transform a three-level list of strings into one of indices.
        [
            [word_11, word_12, ...],
            [word_21, word_22, ...],
            ...
        ]
    """
    word2index = load_pickle(pickle_path, "word2id.pkl")
    data_index = []
    for example in data:
        data_index.append([word2index.get(w, DEFAULT_UNKNOWN_LABEL) for w in example])
    return data_index
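A tiny, self-contained illustration of the data layout POSPreprocess above expects and of what build_dict / to_index produce; the two-sentence dataset and its tags are invented for the example:

data = [
    [["i", "love", "nlp"], ["O", "O", "B"]],
    [["nlp", "is", "fun"], ["B", "O", "O"]],
]

word2index = {"<pad>": 0, "<unk>": 1}
label2index = {}
for words, labels in data:
    for w, t in zip(words, labels):
        word2index.setdefault(w, len(word2index))
        label2index.setdefault(t, len(label2index))

data_index = [[[word2index[w] for w in words], [label2index[t] for t in labels]]
              for words, labels in data]
print(word2index)  # {'<pad>': 0, '<unk>': 1, 'i': 2, 'love': 3, 'nlp': 4, 'is': 5, 'fun': 6}
print(data_index)  # [[[2, 3, 4], [0, 0, 1]], [[4, 5, 6], [1, 0, 0]]]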
fastNLP/models/sequence_modeling.py
@@ -9,17 +9,12 @@ class SeqLabeling(BaseModel):
    PyTorch Network for sequence labeling
    """

    def __init__(self, hidden_dim,
                 rnn_num_layer,
                 num_classes,
                 vocab_size,
                 word_emb_dim=100,
                 init_emb=None,
                 rnn_mode="gru",
                 bi_direction=False,
                 dropout=0.5,
                 use_crf=True):
    def __init__(self, args):
        super(SeqLabeling, self).__init__()
        vocab_size = args["vocab_size"]
        word_emb_dim = args["word_emb_dim"]
        hidden_dim = args["rnn_hidden_units"]
        num_classes = args["num_classes"]

        self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim)
        self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim)
@@ -29,7 +24,7 @@ class SeqLabeling(BaseModel):
    def forward(self, x):
        """
        :param x: LongTensor, [batch_size, max_len]
        :return y: [batch_size, tag_size, tag_size]
        :return y: [batch_size, max_len, tag_size]
        """
        x = self.Embedding(x)
        # [batch_size, max_len, word_emb_dim]
@@ -64,7 +59,7 @@ class SeqLabeling(BaseModel):

    def prediction(self, x, seq_length):
        """
        :param x: FloatTensor, [batch_size, tag_size, tag_size]
        :param x: FloatTensor, [batch_size, max_len, tag_size]
        :param seq_length: int
        :return prediction: list of tuple of (decode path(list), best score)
        """
@@ -13,7 +13,7 @@ class Lstm(nn.Module):
        bidirectional : If True, becomes a bidirectional RNN. Default: False.
    """

    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.5, bidirectional=False):
    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False):
        super(Lstm, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
                            dropout=dropout, bidirectional=bidirectional)
@@ -74,3 +74,9 @@ save_dev_input = false
save_loss = true
batch_size = 1
pickle_path = "./data_for_tests/"
rnn_hidden_units = 100
rnn_layers = 1
rnn_bi_direction = true
word_emb_dim = 100
dropout = 0.5
use_crf = true
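A hedged sketch of how the new config keys above end up in SeqLabeling(args) (see the SeqLabeling diff earlier in this commit): they are loaded into a ConfigSection, vocab_size and num_classes are filled in from the pickles at runtime, and the model reads them by key. A plain dict stands in for a ConfigSection here, and the numeric values are placeholders:

train_args = {
    "vocab_size": 20000,       # placeholder; set from word2id.pkl at runtime
    "word_emb_dim": 100,
    "rnn_hidden_units": 100,
    "num_classes": 10,         # placeholder; set from id2class.pkl at runtime
}
# model = SeqLabeling(train_args)  # reads args["vocab_size"], args["word_emb_dim"], ...
print(train_args["rnn_hidden_units"])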
test/data_for_tests/people_infer.txt  (new file, 2 lines)
@@ -0,0 +1,2 @@
迈向充满希望的新世纪——一九九八年新年讲话
(附图片1张)
test_POS_pipeline.py
@@ -4,8 +4,8 @@ sys.path.append("..")

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.action.trainer import POSTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess
from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.action.tester import POSTester
@@ -15,32 +15,49 @@ from fastNLP.action.inference import Inference
data_name = "people.txt"
data_path = "data_for_tests/people.txt"
pickle_path = "data_for_tests"
data_infer_path = "data_for_tests/people_infer.txt"


def test_infer():
def infer():
    # Load infer configuration, the same as test
    test_args = ConfigSection()
    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})

    # fetch dictionary size and number of labels from pickle files
    word2index = load_pickle(pickle_path, "word2id.pkl")
    test_args["vocab_size"] = len(word2index)
    index2label = load_pickle(pickle_path, "id2class.pkl")
    test_args["num_classes"] = len(index2label)

    # Define the same model
    model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"],
                        num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"],
                        word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"],
                        rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"])
    model = SeqLabeling(test_args)

    # Dump trained parameters into the model
    ModelLoader("arbitrary_name", "./saved_model.pkl").load_pytorch(model)
    ModelLoader.load_pytorch(model, "./saved_model.pkl")
    print("model loaded!")

    # Data Loader
    pos_loader = POSDatasetLoader(data_name, data_path)
    infer_data = pos_loader.load_lines()

    # Preprocessor
    POSPreprocess(infer_data, pickle_path)
    raw_data_loader = BaseLoader(data_name, data_infer_path)
    infer_data = raw_data_loader.load_lines()
    """
    Transform strings into a list of lists of strings.
        [
            [word_11, word_12, ...],
            [word_21, word_22, ...],
            ...
        ]
    In this case, each line in "people_infer.txt" is already a sentence. So load_lines() just splits them.
    """

    # Inference interface
    infer = Inference()
    infer = Inference(pickle_path)
    results = infer.predict(model, infer_data)

    print(results)
    print("Inference finished!")


if __name__ == "__main__":

def train_test():
    # Config Loader
    train_args = ConfigSection()
    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
@@ -58,10 +75,7 @@ if __name__ == "__main__":
    trainer = POSTrainer(train_args)

    # Model
    model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"],
                        num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"],
                        word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"],
                        rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"])
    model = SeqLabeling(train_args)

    # Start training
    trainer.train(model)
@@ -75,13 +89,10 @@ if __name__ == "__main__":
    del model, trainer, pos_loader

    # Define the same model
    model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"],
                        num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"],
                        word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"],
                        rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"])
    model = SeqLabeling(train_args)

    # Dump trained parameters into the model
    ModelLoader("arbitrary_name", "./saved_model.pkl").load_pytorch(model)
    ModelLoader.load_pytorch(model, "./saved_model.pkl")
    print("model loaded!")

    # Load test configuration
@@ -97,3 +108,7 @@ if __name__ == "__main__":
    # print test results
    print(tester.show_matrices())
    print("model tested!")


if __name__ == "__main__":
    infer()