Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-03 04:37:37 +08:00)
Commit 9733249b5e
@@ -27,8 +27,8 @@ class Predictor(object):
         self.batch_output = []
         self.pickle_path = pickle_path
         self._task = task  # one of ("seq_label", "text_classify")
-        self.index2label = load_pickle(self.pickle_path, "id2class.pkl")
-        self.word2index = load_pickle(self.pickle_path, "word2id.pkl")
+        self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl")
+        self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl")

     def predict(self, network, data):
         """Perform inference using the trained model.
@@ -82,7 +82,7 @@ class Predictor(object):
         :return data_set: a DataSet instance.
         """
         assert isinstance(data, list)
-        return create_dataset_from_lists(data, self.word2index, has_target=False)
+        return create_dataset_from_lists(data, self.word_vocab, has_target=False)

     def prepare_output(self, data):
         """Transform list of batch outputs into strings."""
@@ -97,14 +97,14 @@ class Predictor(object):
         results = []
         for batch in batch_outputs:
             for example in np.array(batch):
-                results.append([self.index2label[int(x)] for x in example])
+                results.append([self.label_vocab.to_word(int(x)) for x in example])
         return results

     def _text_classify_prepare_output(self, batch_outputs):
         results = []
         for batch_out in batch_outputs:
             idx = np.argmax(batch_out.detach().numpy(), axis=-1)
-            results.extend([self.index2label[i] for i in idx])
+            results.extend([self.label_vocab.to_word(i) for i in idx])
         return results
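Note (editor's aside, not part of the commit): the practical effect of this hunk is that predicted indices are decoded through Vocabulary.to_word instead of indexing a raw index2label dict. A minimal sketch, with invented labels and model output:

    from fastNLP.core.vocabulary import Vocabulary

    # hypothetical label vocabulary, built without the reserved default entries
    label_vocab = Vocabulary(need_default=False)
    label_vocab.update(["B-PER", "I-PER", "O"])   # 'B-PER' -> 0, 'I-PER' -> 1, 'O' -> 2

    predicted_indices = [2, 2, 0, 1]              # fake per-token model output
    tags = [label_vocab.to_word(int(x)) for x in predicted_indices]
    print(tags)                                   # ['O', 'O', 'B-PER', 'I-PER']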
@@ -6,16 +6,7 @@ import numpy as np
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
-
-DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
-DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
-DEFAULT_RESERVED_LABEL = ['<reserved-2>',
-                          '<reserved-3>',
-                          '<reserved-4>']  # dict index = 2~4
-
-DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
-                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
-                         DEFAULT_RESERVED_LABEL[2]: 4}
+from fastNLP.core.vocabulary import Vocabulary


 # the first vocab in dict with the index = 5
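Note (editor's aside): the DEFAULT_* constants removed here reappear in the new fastNLP/core/vocabulary.py below, where Vocabulary pre-inserts them. One consequence to keep in mind when sizing models: a fresh Vocabulary already contains five entries, so vocab_size and num_classes now count the defaults. A quick illustration:

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary()                 # <pad>, <unk>, <reserved-2..4> pre-inserted
    print(len(vocab))                    # 5 -- indices 0..4 are taken by the defaults
    vocab.update("hello world".split())
    print(len(vocab))                    # 7 -- 'hello' -> 5, 'world' -> 6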
@@ -68,24 +59,22 @@ class BasePreprocess(object):

         - "word2id.pkl", a mapping from words (tokens) to indices
         - "id2word.pkl", a reversed dictionary
         - "label2id.pkl", a dictionary on labels
         - "id2label.pkl", a reversed dictionary on labels

     These four pickle files are expected to be saved in the given pickle directory once they are constructed.
     Preprocessors will check if those files are already in the directory and will reuse them in future calls.
     """

     def __init__(self):
-        self.word2index = None
-        self.label2index = None
+        self.data_vocab = Vocabulary()
+        self.label_vocab = Vocabulary()

     @property
     def vocab_size(self):
-        return len(self.word2index)
+        return len(self.data_vocab)

     @property
     def num_classes(self):
-        return len(self.label2index)
+        return len(self.label_vocab)

     def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
         """Main pre-processing pipeline.
@@ -102,20 +91,14 @@ class BasePreprocess(object):
         """

         if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
-            self.word2index = load_pickle(pickle_path, "word2id.pkl")
-            self.label2index = load_pickle(pickle_path, "class2id.pkl")
+            self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
+            self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
         else:
-            self.word2index, self.label2index = self.build_dict(train_dev_data)
-            save_pickle(self.word2index, pickle_path, "word2id.pkl")
-            save_pickle(self.label2index, pickle_path, "class2id.pkl")
+            self.data_vocab, self.label_vocab = self.build_dict(train_dev_data)
+            save_pickle(self.data_vocab, pickle_path, "word2id.pkl")
+            save_pickle(self.label_vocab, pickle_path, "class2id.pkl")

-        if not pickle_exist(pickle_path, "id2word.pkl"):
-            index2word = self.build_reverse_dict(self.word2index)
-            save_pickle(index2word, pickle_path, "id2word.pkl")
-
-        if not pickle_exist(pickle_path, "id2class.pkl"):
-            index2label = self.build_reverse_dict(self.label2index)
-            save_pickle(index2label, pickle_path, "id2class.pkl")
+        self.build_reverse_dict()

         train_set = []
         dev_set = []
@@ -125,13 +108,13 @@ class BasePreprocess(object):
             split = int(len(train_dev_data) * train_dev_split)
             data_dev = train_dev_data[: split]
             data_train = train_dev_data[split:]
-            train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index)
-            dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index)
+            train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab)
+            dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab)

             save_pickle(dev_set, pickle_path, "data_dev.pkl")
             print("{} of the training data is split for validation.".format(train_dev_split))
         else:
-            train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index)
+            train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab)
             save_pickle(train_set, pickle_path, "data_train.pkl")
         else:
             train_set = load_pickle(pickle_path, "data_train.pkl")
@@ -143,8 +126,8 @@ class BasePreprocess(object):
             # cross validation
             data_cv = self.cv_split(train_dev_data, n_fold)
             for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
-                data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index)
-                data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index)
+                data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab)
+                data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab)
                 save_pickle(
                     data_train_cv, pickle_path,
                     "data_train_{}.pkl".format(i))
@@ -165,7 +148,7 @@ class BasePreprocess(object):
         test_set = []
         if test_data is not None:
             if not pickle_exist(pickle_path, "data_test.pkl"):
-                test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index)
+                test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab)
                 save_pickle(test_set, pickle_path, "data_test.pkl")

         # return preprocessed results
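Note (editor's aside): run() consumes a list of (token_list, label) examples, as the rewritten build_dict in the next hunk makes explicit. A hedged sketch of a call, assuming a concrete subclass such as SeqLabelPreprocess supplies convert_to_dataset (the data values are invented):

    from fastNLP.core.preprocess import SeqLabelPreprocess  # assumed subclass of BasePreprocess

    train_data = [
        (["EU", "rejects", "German", "call"], ["B-ORG", "O", "B-MISC", "O"]),
        (["Peter", "Blackburn"], ["B-PER", "I-PER"]),
    ]
    p = SeqLabelPreprocess()
    results = p.run(train_data, pickle_path="./save/")  # pickles vocabs and datasets under ./save/
    print(p.vocab_size, p.num_classes)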
@@ -180,28 +163,15 @@ class BasePreprocess(object):
         return tuple(results)

     def build_dict(self, data):
-        label2index = DEFAULT_WORD_TO_INDEX.copy()
-        word2index = DEFAULT_WORD_TO_INDEX.copy()
         for example in data:
-            for word in example[0]:
-                if word not in word2index:
-                    word2index[word] = len(word2index)
-            label = example[1]
-            if isinstance(label, str):
-                # label is a string
-                if label not in label2index:
-                    label2index[label] = len(label2index)
-            elif isinstance(label, list):
-                # label is a list of strings
-                for single_label in label:
-                    if single_label not in label2index:
-                        label2index[single_label] = len(label2index)
-        return word2index, label2index
+            word, label = example
+            self.data_vocab.update(word)
+            self.label_vocab.update(label)
+        return self.data_vocab, self.label_vocab

-    def build_reverse_dict(self, word_dict):
-        id2word = {word_dict[w]: w for w in word_dict}
-        return id2word
+    def build_reverse_dict(self):
+        self.data_vocab.build_reverse_vocab()
+        self.label_vocab.build_reverse_vocab()

     def data_split(self, data, train_dev_split):
         """Split data into train and dev set."""
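Note (editor's aside): the rewrite of build_dict leans on Vocabulary.update accepting either a single string or an arbitrarily nested iterable of strings, which is why one call now covers both a string label and a list of labels. A small sketch with invented labels:

    from fastNLP.core.vocabulary import Vocabulary

    label_vocab = Vocabulary()
    label_vocab.update("POSITIVE")           # a single string label is added as one entry
    label_vocab.update(["B-PER", "I-PER"])   # a list of labels is walked recursively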
124  fastNLP/core/vocabulary.py  (new file)
@@ -0,0 +1,124 @@
+from copy import deepcopy
+
+DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
+DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
+DEFAULT_RESERVED_LABEL = ['<reserved-2>',
+                          '<reserved-3>',
+                          '<reserved-4>']  # dict index = 2~4
+
+DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
+                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
+                         DEFAULT_RESERVED_LABEL[2]: 4}
+
+
+def isiterable(p_object):
+    try:
+        iter(p_object)
+    except TypeError:
+        return False
+    return True
+
+
+class Vocabulary(object):
+    """A one-to-one mapping between words and indices.
+
+    Example::
+
+        vocab = Vocabulary()
+        word_list = "this is a word list".split()
+        vocab.update(word_list)
+        vocab["word"]
+        vocab.to_word(5)
+    """
+
+    def __init__(self, need_default=True):
+        """
+        :param bool need_default: whether the Vocabulary reserves the default labels (padding, unknown, reserved).
+        """
+        if need_default:
+            self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
+            self.padding_label = DEFAULT_PADDING_LABEL
+            self.unknown_label = DEFAULT_UNKNOWN_LABEL
+        else:
+            self.word2idx = {}
+            self.padding_label = None
+            self.unknown_label = None
+
+        self.has_default = need_default
+        self.idx2word = None
+
+    def __len__(self):
+        return len(self.word2idx)
+
+    def update(self, word):
+        """Add a word, or a (possibly nested) iterable of words, to the Vocabulary.
+
+        :param word: a str, or an iterable of str
+        """
+        if not isinstance(word, str) and isiterable(word):
+            # a (possibly nested) collection of words: recurse into it
+            for w in word:
+                self.update(w)
+        else:
+            # a single word to be added
+            if word not in self.word2idx:
+                self.word2idx[word] = len(self)
+            if self.idx2word is not None:
+                # invalidate the cached reverse vocab; it is rebuilt lazily
+                self.idx2word = None
+
+    def __getitem__(self, w):
+        """To support usage like::
+
+            vocab[w]
+        """
+        if w in self.word2idx:
+            return self.word2idx[w]
+        else:
+            return self.word2idx[DEFAULT_UNKNOWN_LABEL]
+
+    def to_index(self, w):
+        """Turn a word into its index; if w is not in the Vocabulary,
+        return the index of the unknown label.
+
+        :param str w:
+        """
+        return self[w]
+
+    def unknown_idx(self):
+        if self.unknown_label is None:
+            return None
+        return self.word2idx[self.unknown_label]
+
+    def padding_idx(self):
+        if self.padding_label is None:
+            return None
+        return self.word2idx[self.padding_label]
+
+    def build_reverse_vocab(self):
+        """Build the 'index to word' dict from the 'word to index' dict."""
+        self.idx2word = {self.word2idx[w]: w for w in self.word2idx}
+
+    def to_word(self, idx):
+        """Given a word's index, return the word itself.
+
+        :param int idx:
+        """
+        if self.idx2word is None:
+            self.build_reverse_vocab()
+        return self.idx2word[idx]
+
+    def __getstate__(self):
+        """Prepare data for pickling."""
+        state = self.__dict__.copy()
+        # no need to pickle idx2word, as it can be rebuilt from word2idx
+        del state['idx2word']
+        return state
+
+    def __setstate__(self, state):
+        """Restore state from pickling."""
+        self.__dict__.update(state)
+        self.idx2word = None
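Note (editor's aside): the __getstate__/__setstate__ pair keeps pickles compact by dropping the reverse map on save; to_word rebuilds it lazily after load. A minimal round-trip, assuming the class behaves as written above:

    import pickle
    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary()
    vocab.update(["apple", "banana"])       # 'apple' -> 5, 'banana' -> 6
    assert vocab.to_word(5) == "apple"      # forces idx2word to be built

    restored = pickle.loads(pickle.dumps(vocab))  # idx2word is not serialized
    assert restored.idx2word is None
    assert restored.to_word(6) == "banana"        # reverse map rebuilt on demand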
@@ -69,7 +69,7 @@ class FastNLP(object):
         :param model_dir: this directory should contain the following files:
             1. a pre-trained model
             2. a config file
-            3. "id2class.pkl"
+            3. "class2id.pkl"
             4. "word2id.pkl"
         """
         self.model_dir = model_dir
@@ -99,10 +99,10 @@ class FastNLP(object):
         print("Restore model hyper-parameters {}".format(str(model_args.data)))

         # fetch dictionary size and number of labels from pickle files
-        word2index = load_pickle(self.model_dir, "word2id.pkl")
-        model_args["vocab_size"] = len(word2index)
-        index2label = load_pickle(self.model_dir, "id2class.pkl")
-        model_args["num_classes"] = len(index2label)
+        word_vocab = load_pickle(self.model_dir, "word2id.pkl")
+        model_args["vocab_size"] = len(word_vocab)
+        label_vocab = load_pickle(self.model_dir, "class2id.pkl")
+        model_args["num_classes"] = len(label_vocab)

         # Construct the model
         model = model_class(model_args)
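Note (editor's aside): the same renaming recurs in the example scripts below, and functionally nothing changes there: word2id.pkl and class2id.pkl now hold Vocabulary objects rather than plain dicts, and since Vocabulary defines __len__, len() still yields the sizes. The recurring idiom, sketched (path invented; load_pickle is assumed to live in fastNLP.core.preprocess alongside save_pickle):

    from fastNLP.core.preprocess import load_pickle

    word_vocab = load_pickle("./save/", "word2id.pkl")    # now a Vocabulary
    label_vocab = load_pickle("./save/", "class2id.pkl")  # now a Vocabulary

    model_args = {"vocab_size": len(word_vocab), "num_classes": len(label_vocab)}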
@@ -32,7 +32,7 @@ def infer():

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)

@@ -105,7 +105,7 @@ def test():

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)

     # load dev data

@@ -33,7 +33,7 @@ def infer():

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)

     # Define the same model

@@ -105,7 +105,7 @@ def test():

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)

     # load dev data
@@ -4,6 +4,7 @@ import unittest

 from fastNLP.core.predictor import Predictor
 from fastNLP.core.preprocess import save_pickle
 from fastNLP.models.sequence_modeling import SeqLabeling
+from fastNLP.core.vocabulary import Vocabulary


 class TestPredictor(unittest.TestCase):
@@ -23,10 +24,14 @@ class TestPredictor(unittest.TestCase):
             ['a', 'b', 'c', 'd', '$'],
             ['!', 'b', 'c', 'd', 'e']
         ]
-        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
+        vocab = Vocabulary()
+        vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
+        class_vocab = Vocabulary()
+        class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}

         os.system("mkdir save")
-        save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl")
+        save_pickle(class_vocab, "./save/", "class2id.pkl")
         save_pickle(vocab, "./save/", "word2id.pkl")

         model = SeqLabeling(model_args)
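Note (editor's aside): the test assigns vocab.word2idx by hand to pin exact indices. An equivalent setup through the public API would be the following sketch, relying on update() assigning indices in insertion order when need_default=False:

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary(need_default=False)
    vocab.update(['a', 'b', 'c', 'd', 'e', '!', '@', '#', '$', '?'])
    # same mapping as the hand-written dict: 'a' -> 0, ..., '?' -> 9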
31  test/core/test_vocab.py  (new file)
@@ -0,0 +1,31 @@
+import unittest
+from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX
+
+
+class TestVocabulary(unittest.TestCase):
+    def test_vocab(self):
+        import _pickle as pickle
+        import os
+        vocab = Vocabulary()
+        filename = 'vocab'
+        vocab.update(filename)
+        vocab.update([filename, ['a'], [['b']], ['c']])
+        idx = vocab[filename]
+        before_pic = (vocab.to_word(idx), vocab[filename])
+
+        with open(filename, 'wb') as f:
+            pickle.dump(vocab, f)
+        with open(filename, 'rb') as f:
+            vocab = pickle.load(f)
+        os.remove(filename)
+
+        vocab.build_reverse_vocab()
+        after_pic = (vocab.to_word(idx), vocab[filename])
+        TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8}
+        TRUE_DICT.update(DEFAULT_WORD_TO_INDEX)
+        TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'}
+        self.assertEqual(before_pic, after_pic)
+        self.assertDictEqual(TRUE_DICT, vocab.word2idx)
+        self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word)
+
+if __name__ == '__main__':
+    unittest.main()
@@ -38,7 +38,7 @@ def infer():

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)

     # Define the same model

@@ -27,7 +27,7 @@ def infer():

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)

     # Define the same model