fix vocab

This commit is contained in:
parent 9c7f3cf261
commit 819c8f05be
@@ -27,8 +27,8 @@ class Predictor(object):
         self.batch_output = []
         self.pickle_path = pickle_path
         self._task = task  # one of ("seq_label", "text_classify")
-        self.index2label = load_pickle(self.pickle_path, "id2class.pkl")
-        self.word2index = load_pickle(self.pickle_path, "word2id.pkl")
+        self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl")
+        self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl")
 
     def predict(self, network, data):
         """Perform inference using the trained model.
@@ -82,7 +82,7 @@ class Predictor(object):
         :return data_set: a DataSet instance.
         """
         assert isinstance(data, list)
-        return create_dataset_from_lists(data, self.word2index, has_target=False)
+        return create_dataset_from_lists(data, self.word_vocab, has_target=False)
 
     def prepare_output(self, data):
         """Transform list of batch outputs into strings."""
@@ -97,14 +97,14 @@ class Predictor(object):
         results = []
         for batch in batch_outputs:
             for example in np.array(batch):
-                results.append([self.index2label[int(x)] for x in example])
+                results.append([self.label_vocab.to_word(int(x)) for x in example])
         return results
 
     def _text_classify_prepare_output(self, batch_outputs):
         results = []
         for batch_out in batch_outputs:
             idx = np.argmax(batch_out.detach().numpy(), axis=-1)
-            results.extend([self.index2label[i] for i in idx])
+            results.extend([self.label_vocab.to_word(i) for i in idx])
         return results
 
 
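Taken together, the Predictor changes replace the raw index2label dict with a Vocabulary loaded from "class2id.pkl" and resolve predicted indices through Vocabulary.to_word. Below is a minimal standalone sketch of that lookup pattern; the toy label vocabulary mirrors the updated test later in this diff, and it is assumed (as that test relies on) that assigning word2idx directly is enough for to_word to resolve indices.

import numpy as np
from fastNLP.core.vocabulary import Vocabulary

# Toy label vocabulary, built the same way as the updated test below.
label_vocab = Vocabulary()
label_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}

# A fake batch of predicted label indices (batch_size x seq_len).
batch = np.array([[0, 2, 4], [1, 1, 3]])

# Old: results.append([index2label[int(x)] for x in example])
# New: indices are mapped back to label strings via Vocabulary.to_word.
results = [[label_vocab.to_word(int(x)) for x in example] for example in batch]
print(results)  # expected: [['0', '2', '4'], ['1', '1', '3']]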
@@ -69,7 +69,7 @@ class FastNLP(object):
         :param model_dir: this directory should contain the following files:
             1. a pre-trained model
             2. a config file
-            3. "id2class.pkl"
+            3. "class2id.pkl"
             4. "word2id.pkl"
         """
         self.model_dir = model_dir
@@ -99,10 +99,10 @@ class FastNLP(object):
         print("Restore model hyper-parameters {}".format(str(model_args.data)))
 
         # fetch dictionary size and number of labels from pickle files
-        word2index = load_pickle(self.model_dir, "word2id.pkl")
-        model_args["vocab_size"] = len(word2index)
-        index2label = load_pickle(self.model_dir, "id2class.pkl")
-        model_args["num_classes"] = len(index2label)
+        word_vocab = load_pickle(self.model_dir, "word2id.pkl")
+        model_args["vocab_size"] = len(word_vocab)
+        label_vocab = load_pickle(self.model_dir, "class2id.pkl")
+        model_args["num_classes"] = len(label_vocab)
 
         # Construct the model
         model = model_class(model_args)
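The FastNLP loader now reads both pickles as Vocabulary objects and sizes the model from them. A hedged sketch of that loading step follows, assuming load_pickle lives next to save_pickle in fastNLP.core.preprocess and keeps the (directory, filename) signature used throughout this diff; "./save/" is a placeholder model directory, and the plain dict stands in for the real config object.

from fastNLP.core.preprocess import load_pickle  # assumed import path, alongside save_pickle

model_dir = "./save/"  # placeholder; must contain the pickles listed in the docstring above

# Both files now hold Vocabulary objects rather than plain dicts.
word_vocab = load_pickle(model_dir, "word2id.pkl")
label_vocab = load_pickle(model_dir, "class2id.pkl")

# len() on a Vocabulary gives its size, which is all the model construction needs.
model_args = {
    "vocab_size": len(word_vocab),    # dictionary size
    "num_classes": len(label_vocab),  # number of labels
}
print(model_args)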
@@ -32,7 +32,7 @@ def infer():
     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)
 
 
@@ -105,7 +105,7 @@ def test():
     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)
 
     # load dev data
@@ -33,7 +33,7 @@ def infer():
     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)
 
     # Define the same model
@@ -105,7 +105,7 @@ def test():
     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)
 
     # load dev data
@@ -4,6 +4,7 @@ import unittest
 from fastNLP.core.predictor import Predictor
 from fastNLP.core.preprocess import save_pickle
 from fastNLP.models.sequence_modeling import SeqLabeling
+from fastNLP.core.vocabulary import Vocabulary
 
 
 class TestPredictor(unittest.TestCase):
@@ -23,10 +24,14 @@ class TestPredictor(unittest.TestCase):
             ['a', 'b', 'c', 'd', '$'],
             ['!', 'b', 'c', 'd', 'e']
         ]
-        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
+
+        vocab = Vocabulary()
+        vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
+        class_vocab = Vocabulary()
+        class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4}
 
         os.system("mkdir save")
-        save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl")
+        save_pickle(class_vocab, "./save/", "class2id.pkl")
         save_pickle(vocab, "./save/", "word2id.pkl")
 
         model = SeqLabeling(model_args)
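The updated test builds real Vocabulary objects and pickles them under the file names the Predictor now expects. A round-trip sketch of that setup is below, assuming load_pickle is importable from fastNLP.core.preprocess alongside save_pickle and that the (object, directory, filename) / (directory, filename) signatures from this diff hold; "./save/" is the same scratch directory the test uses.

import os

from fastNLP.core.preprocess import save_pickle, load_pickle  # load_pickle import path assumed
from fastNLP.core.vocabulary import Vocabulary

# Build the two vocabularies exactly as the updated test does.
vocab = Vocabulary()
vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
class_vocab = Vocabulary()
class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}

# Persist them under the names the Predictor and the FastNLP loader now look for.
os.makedirs("./save/", exist_ok=True)
save_pickle(class_vocab, "./save/", "class2id.pkl")
save_pickle(vocab, "./save/", "word2id.pkl")

# Loading them back yields Vocabulary objects, so len() and to_word() work downstream.
restored = load_pickle("./save/", "class2id.pkl")
print(len(restored), restored.to_word(0))  # expected: 5 '0'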
@@ -38,7 +38,7 @@ def infer():
     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)
 
     # Define the same model
@@ -27,7 +27,7 @@ def infer():
     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
     test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "id2class.pkl")
+    index2label = load_pickle(pickle_path, "class2id.pkl")
     test_args["num_classes"] = len(index2label)
 
     # Define the same model