add vocabulary into preprocessor

This commit is contained in:
yunfan 2018-09-18 16:43:56 +08:00
parent 3f4544759d
commit 9c7f3cf261
4 changed files with 52 additions and 130 deletions

View File

@ -6,16 +6,7 @@ import numpy as np
from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1
DEFAULT_RESERVED_LABEL = ['<reserved-2>',
'<reserved-3>',
'<reserved-4>'] # dict index = 2~4
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
DEFAULT_RESERVED_LABEL[2]: 4}
from fastNLP.core.vocabulary import Vocabulary
# the first vocab in dict with the index = 5
@ -68,24 +59,22 @@ class BasePreprocess(object):
- "word2id.pkl", a mapping from words(tokens) to indices
- "id2word.pkl", a reversed dictionary
- "label2id.pkl", a dictionary on labels
- "id2label.pkl", a reversed dictionary on labels
These four pickle files are expected to be saved in the given pickle directory once they are constructed.
Preprocessors will check if those files are already in the directory and will reuse them in future calls.
"""
def __init__(self):
self.word2index = None
self.label2index = None
self.data_vocab = Vocabulary()
self.label_vocab = Vocabulary()
@property
def vocab_size(self):
return len(self.word2index)
return len(self.data_vocab)
@property
def num_classes(self):
return len(self.label2index)
return len(self.label_vocab)
def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
"""Main pre-processing pipeline.
@ -102,20 +91,14 @@ class BasePreprocess(object):
"""
if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
self.word2index = load_pickle(pickle_path, "word2id.pkl")
self.label2index = load_pickle(pickle_path, "class2id.pkl")
self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
else:
self.word2index, self.label2index = self.build_dict(train_dev_data)
save_pickle(self.word2index, pickle_path, "word2id.pkl")
save_pickle(self.label2index, pickle_path, "class2id.pkl")
self.data_vocab, self.label_vocab = self.build_dict(train_dev_data)
save_pickle(self.data_vocab, pickle_path, "word2id.pkl")
save_pickle(self.label_vocab, pickle_path, "class2id.pkl")
if not pickle_exist(pickle_path, "id2word.pkl"):
index2word = self.build_reverse_dict(self.word2index)
save_pickle(index2word, pickle_path, "id2word.pkl")
if not pickle_exist(pickle_path, "id2class.pkl"):
index2label = self.build_reverse_dict(self.label2index)
save_pickle(index2label, pickle_path, "id2class.pkl")
self.build_reverse_dict()
train_set = []
dev_set = []
@ -125,13 +108,13 @@ class BasePreprocess(object):
split = int(len(train_dev_data) * train_dev_split)
data_dev = train_dev_data[: split]
data_train = train_dev_data[split:]
train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index)
dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index)
train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab)
dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab)
save_pickle(dev_set, pickle_path, "data_dev.pkl")
print("{} of the training data is split for validation. ".format(train_dev_split))
else:
train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index)
train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab)
save_pickle(train_set, pickle_path, "data_train.pkl")
else:
train_set = load_pickle(pickle_path, "data_train.pkl")
@ -143,8 +126,8 @@ class BasePreprocess(object):
# cross validation
data_cv = self.cv_split(train_dev_data, n_fold)
for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index)
data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index)
data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab)
data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab)
save_pickle(
data_train_cv, pickle_path,
"data_train_{}.pkl".format(i))
@ -165,7 +148,7 @@ class BasePreprocess(object):
test_set = []
if test_data is not None:
if not pickle_exist(pickle_path, "data_test.pkl"):
test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index)
test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab)
save_pickle(test_set, pickle_path, "data_test.pkl")
# return preprocessed results
@ -180,28 +163,15 @@ class BasePreprocess(object):
return tuple(results)
def build_dict(self, data):
label2index = DEFAULT_WORD_TO_INDEX.copy()
word2index = DEFAULT_WORD_TO_INDEX.copy()
for example in data:
for word in example[0]:
if word not in word2index:
word2index[word] = len(word2index)
label = example[1]
if isinstance(label, str):
# label is a string
if label not in label2index:
label2index[label] = len(label2index)
elif isinstance(label, list):
# label is a list of strings
for single_label in label:
if single_label not in label2index:
label2index[single_label] = len(label2index)
return word2index, label2index
word, label = example
self.data_vocab.update(word)
self.label_vocab.update(label)
return self.data_vocab, self.label_vocab
def build_reverse_dict(self, word_dict):
id2word = {word_dict[w]: w for w in word_dict}
return id2word
def build_reverse_dict(self):
self.data_vocab.build_reverse_vocab()
self.label_vocab.build_reverse_vocab()
def data_split(self, data, train_dev_split):
"""Split data into train and dev set."""

View File

@ -18,6 +18,16 @@ def isiterable(p_object):
return True
class Vocabulary(object):
"""Use for word and index one to one mapping
Example::
vocab = Vocabulary()
word_list = "this is a word list".split()
vocab.update(word_list)
vocab["word"]
vocab.to_word(5)
"""
def __init__(self):
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
self.padding_label = DEFAULT_PADDING_LABEL
@ -29,6 +39,8 @@ class Vocabulary(object):
def update(self, word):
"""add word or list of words into Vocabulary
:param word: a list of str or str
"""
if not isinstance(word, str) and isiterable(word):
# it's a nested list
@ -43,13 +55,22 @@ class Vocabulary(object):
def __getitem__(self, w):
""" like to_index(w) function, turn a word to the index
if w is not in Vocabulary, return the unknown label
"""To support usage like::
vocab[w]
"""
if w in self.word2idx:
return self.word2idx[w]
else:
return self.word2idx[DEFAULT_UNKNOWN_LABEL]
def to_index(self, w):
""" like to_index(w) function, turn a word to the index
if w is not in Vocabulary, return the unknown label
:param str w:
"""
return self[w]
def unknown_idx(self):
return self.word2idx[self.unknown_label]
@ -58,10 +79,14 @@ class Vocabulary(object):
return self.word2idx[self.padding_label]
def build_reverse_vocab(self):
"""build 'index to word' dict based on 'word to index' dict
"""
self.idx2word = {self.word2idx[w] : w for w in self.word2idx}
def to_word(self, idx):
"""given a word's index, return the word itself
:param int idx:
"""
if self.idx2word is None:
self.build_reverse_vocab()

View File

@ -1,69 +0,0 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
import unittest
import torch
from fastNLP.data.field import TextField, LabelField
from fastNLP.data.instance import Instance
from fastNLP.data.dataset import DataSet
from fastNLP.data.batch import Batch
class TestField(unittest.TestCase):
def check_batched_data_equal(self, data1, data2):
self.assertEqual(len(data1), len(data2))
for i in range(len(data1)):
self.assertTrue(data1[i].keys(), data2[i].keys())
for i in range(len(data1)):
for t1, t2 in zip(data1[i].values(), data2[i].values()):
self.assertTrue(torch.equal(t1, t2))
def test_batchiter(self):
texts = [
"i am a cat",
"this is a test of new batch",
"haha"
]
labels = [0, 1, 0]
# prepare vocabulary
vocab = {}
for text in texts:
for tokens in text.split():
if tokens not in vocab:
vocab[tokens] = len(vocab)
# prepare input dataset
data = DataSet()
for text, label in zip(texts, labels):
x = TextField(text.split(), False)
y = LabelField(label, is_target=True)
ins = Instance(text=x, label=y)
data.append(ins)
# use vocabulary to index data
data.index_field("text", vocab)
# define naive sampler for batch class
class SeqSampler:
def __call__(self, dataset):
return list(range(len(dataset)))
# use bacth to iterate dataset
batcher = Batch(data, SeqSampler(), 2)
TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}]
TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}]
for epoch in range(3):
test_x, test_y = [], []
for batch_x, batch_y in batcher:
test_x.append(batch_x)
test_y.append(batch_y)
self.check_batched_data_equal(TRUE_X, test_x)
self.check_batched_data_equal(TRUE_Y, test_y)
if __name__ == "__main__":
unittest.main()

View File

@ -1,9 +1,5 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
import unittest
from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX
from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX
class TestVocabulary(unittest.TestCase):
def test_vocab(self):