commit 9c7f3cf261 (parent 3f4544759d)

    add vocabulary into preprocessor
@@ -6,16 +6,7 @@ import numpy as np
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
 
-DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
-DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
-DEFAULT_RESERVED_LABEL = ['<reserved-2>',
-                          '<reserved-3>',
-                          '<reserved-4>']  # dict index = 2~4
-
-DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
-                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
-                         DEFAULT_RESERVED_LABEL[2]: 4}
+from fastNLP.core.vocabulary import Vocabulary
 
 
 # the first vocab in dict with the index = 5
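
With this hunk the DEFAULT_* constants leave the preprocessor; the Vocabulary class imported in their place now owns them. For orientation, a minimal plain-dict sketch (not the committed code) of the index layout the comment above refers to -- indices 0-4 are reserved, so the first real token lands at index 5::

    DEFAULT_WORD_TO_INDEX = {'<pad>': 0, '<unk>': 1,
                             '<reserved-2>': 2, '<reserved-3>': 3, '<reserved-4>': 4}

    vocab = dict(DEFAULT_WORD_TO_INDEX)
    for token in "hello world".split():
        vocab.setdefault(token, len(vocab))   # next free index on first sight
    assert vocab["hello"] == 5 and vocab["world"] == 6
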
@@ -68,24 +59,22 @@ class BasePreprocess(object):
 
         - "word2id.pkl", a mapping from words(tokens) to indices
         - "id2word.pkl", a reversed dictionary
         - "label2id.pkl", a dictionary on labels
         - "id2label.pkl", a reversed dictionary on labels
 
     These four pickle files are expected to be saved in the given pickle directory once they are constructed.
     Preprocessors will check if those files are already in the directory and will reuse them in future calls.
     """
 
     def __init__(self):
-        self.word2index = None
-        self.label2index = None
+        self.data_vocab = Vocabulary()
+        self.label_vocab = Vocabulary()
 
     @property
     def vocab_size(self):
-        return len(self.word2index)
+        return len(self.data_vocab)
 
     @property
     def num_classes(self):
-        return len(self.label2index)
+        return len(self.label_vocab)
 
     def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
         """Main pre-processing pipeline.
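
The docstring above promises that the four pickles are written once and reused on later calls. The helpers it relies on appear throughout run() with the signatures pickle_exist(path, name), save_pickle(obj, path, name), and load_pickle(path, name); a minimal sketch of that contract on top of the standard pickle module (an approximation, not fastNLP's actual helpers)::

    import os
    import pickle

    def pickle_exist(pickle_path, pickle_name):
        # True iff the cached file is already in the pickle directory
        return os.path.exists(os.path.join(pickle_path, pickle_name))

    def save_pickle(obj, pickle_path, file_name):
        with open(os.path.join(pickle_path, file_name), "wb") as f:
            pickle.dump(obj, f)

    def load_pickle(pickle_path, file_name):
        with open(os.path.join(pickle_path, file_name), "rb") as f:
            return pickle.load(f)
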
@@ -102,20 +91,14 @@ class BasePreprocess(object):
         """
 
         if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
-            self.word2index = load_pickle(pickle_path, "word2id.pkl")
-            self.label2index = load_pickle(pickle_path, "class2id.pkl")
+            self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
+            self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
         else:
-            self.word2index, self.label2index = self.build_dict(train_dev_data)
-            save_pickle(self.word2index, pickle_path, "word2id.pkl")
-            save_pickle(self.label2index, pickle_path, "class2id.pkl")
+            self.data_vocab, self.label_vocab = self.build_dict(train_dev_data)
+            save_pickle(self.data_vocab, pickle_path, "word2id.pkl")
+            save_pickle(self.label_vocab, pickle_path, "class2id.pkl")
 
-        if not pickle_exist(pickle_path, "id2word.pkl"):
-            index2word = self.build_reverse_dict(self.word2index)
-            save_pickle(index2word, pickle_path, "id2word.pkl")
-
-        if not pickle_exist(pickle_path, "id2class.pkl"):
-            index2label = self.build_reverse_dict(self.label2index)
-            save_pickle(index2label, pickle_path, "id2class.pkl")
+        self.build_reverse_dict()
 
         train_set = []
         dev_set = []
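
Note the asymmetry this hunk introduces: the forward vocabularies are still cached as word2id.pkl/class2id.pkl, but the reverse tables (id2word.pkl/id2class.pkl) are no longer pickled at all -- build_reverse_dict() now just populates them on the in-memory vocab objects. From the caller's side the caching contract is unchanged; roughly (the subclass name here is illustrative, since BasePreprocess leaves convert_to_dataset to concrete subclasses)::

    p = SeqLabelPreprocess()                       # hypothetical concrete subclass
    results = p.run(train_dev_data, pickle_path="./cache/")
    # a second run with the same pickle_path reloads word2id.pkl and
    # class2id.pkl instead of calling build_dict again
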
@@ -125,13 +108,13 @@ class BasePreprocess(object):
                 split = int(len(train_dev_data) * train_dev_split)
                 data_dev = train_dev_data[: split]
                 data_train = train_dev_data[split:]
-                train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index)
-                dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index)
+                train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab)
+                dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab)
 
                 save_pickle(dev_set, pickle_path, "data_dev.pkl")
                 print("{} of the training data is split for validation. ".format(train_dev_split))
             else:
-                train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index)
+                train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab)
                 save_pickle(train_set, pickle_path, "data_train.pkl")
         else:
             train_set = load_pickle(pickle_path, "data_train.pkl")
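
One detail worth noticing in the context lines: the dev set takes the head of the data and train the tail. For example, with train_dev_split=0.3 and ten examples::

    split = int(10 * 0.3)                        # 3
    data_dev, data_train = data[:3], data[3:]    # 3 dev examples, 7 train
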
@@ -143,8 +126,8 @@ class BasePreprocess(object):
             # cross validation
             data_cv = self.cv_split(train_dev_data, n_fold)
             for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
-                data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index)
-                data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index)
+                data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab)
+                data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab)
                 save_pickle(
                     data_train_cv, pickle_path,
                     "data_train_{}.pkl".format(i))
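
cv_split itself is not shown in this diff; it is expected to yield n_fold (train, dev) pairs. A plausible standalone sketch of such a splitter, purely for illustration::

    def cv_split(data, n_fold):
        # deal examples round-robin into n_fold folds
        folds = [data[i::n_fold] for i in range(n_fold)]
        for i in range(n_fold):
            dev = folds[i]
            train = [ex for j, fold in enumerate(folds) if j != i for ex in fold]
            yield train, dev
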
@@ -165,7 +148,7 @@ class BasePreprocess(object):
         test_set = []
         if test_data is not None:
             if not pickle_exist(pickle_path, "data_test.pkl"):
-                test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index)
+                test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab)
                 save_pickle(test_set, pickle_path, "data_test.pkl")
 
         # return preprocessed results
@@ -180,28 +163,15 @@ class BasePreprocess(object):
         return tuple(results)
 
     def build_dict(self, data):
-        label2index = DEFAULT_WORD_TO_INDEX.copy()
-        word2index = DEFAULT_WORD_TO_INDEX.copy()
         for example in data:
-            for word in example[0]:
-                if word not in word2index:
-                    word2index[word] = len(word2index)
-            label = example[1]
-            if isinstance(label, str):
-                # label is a string
-                if label not in label2index:
-                    label2index[label] = len(label2index)
-            elif isinstance(label, list):
-                # label is a list of strings
-                for single_label in label:
-                    if single_label not in label2index:
-                        label2index[single_label] = len(label2index)
-        return word2index, label2index
+            word, label = example
+            self.data_vocab.update(word)
+            self.label_vocab.update(label)
+        return self.data_vocab, self.label_vocab
 
 
-    def build_reverse_dict(self, word_dict):
-        id2word = {word_dict[w]: w for w in word_dict}
-        return id2word
+    def build_reverse_dict(self):
+        self.data_vocab.build_reverse_vocab()
+        self.label_vocab.build_reverse_vocab()
 
     def data_split(self, data, train_dev_split):
         """Split data into train and dev set."""
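
The payoff of the commit is visible here: build_dict shrinks from manual dict bookkeeping with separate str/list handling for labels to three lines, because Vocabulary.update accepts a single string, a list, or a nested list alike (per its docstring and the nested-list branch shown below). A quick illustration of the unified call::

    vocab = Vocabulary()
    vocab.update("positive")                 # one label (str)
    vocab.update(["B", "I", "O"])            # a list of labels
    vocab.update([["a", "cat"], ["sat"]])    # a nested list of tokens

The remaining hunks modify the Vocabulary class itself, in the module the preprocessor now imports as fastNLP.core.vocabulary.
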
@@ -18,6 +18,16 @@ def isiterable(p_object):
         return True
 
 class Vocabulary(object):
     """Use for word and index one to one mapping
+
+    Example::
+
+        vocab = Vocabulary()
+        word_list = "this is a word list".split()
+        vocab.update(word_list)
+        vocab["word"]
+        vocab.to_word(5)
+
     """
     def __init__(self):
         self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
         self.padding_label = DEFAULT_PADDING_LABEL
@@ -29,6 +39,8 @@ class Vocabulary(object):
 
     def update(self, word):
         """add word or list of words into Vocabulary
+
+        :param word: a list of str or str
         """
         if not isinstance(word, str) and isiterable(word):
             # it's a nested list
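
Only the nested-list guard is visible in the hunk; the full dispatch presumably recurses on iterables and indexes single tokens, along these lines (a reconstruction, not the committed body)::

    def update(self, word):
        if not isinstance(word, str) and isiterable(word):
            # it's a (possibly nested) list: recurse into the elements
            for w in word:
                self.update(w)
        else:
            # a single token: assign the next free index on first sight
            if word not in self.word2idx:
                self.word2idx[word] = len(self.word2idx)
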
@@ -43,13 +55,22 @@ class Vocabulary(object):
 
     def __getitem__(self, w):
-        """ like to_index(w) function, turn a word to the index
-        if w is not in Vocabulary, return the unknown label
+        """To support usage like::
+
+            vocab[w]
         """
         if w in self.word2idx:
             return self.word2idx[w]
         else:
             return self.word2idx[DEFAULT_UNKNOWN_LABEL]
 
+    def to_index(self, w):
+        """ like to_index(w) function, turn a word to the index
+        if w is not in Vocabulary, return the unknown label
+
+        :param str w:
+        """
+        return self[w]
+
     def unknown_idx(self):
         return self.word2idx[self.unknown_label]
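
With the '<unk>' fallback, lookups never raise, and the new to_index is just a documented alias for __getitem__::

    vocab = Vocabulary()
    vocab.update("hello")
    assert vocab["hello"] == vocab.to_index("hello")
    assert vocab["never-seen"] == DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]  # 1
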
@@ -58,10 +79,14 @@ class Vocabulary(object):
         return self.word2idx[self.padding_label]
 
     def build_reverse_vocab(self):
+        """build 'index to word' dict based on 'word to index' dict
+        """
         self.idx2word = {self.word2idx[w] : w for w in self.word2idx}
 
     def to_word(self, idx):
+        """given a word's index, return the word itself
+
+        :param int idx:
+        """
        if self.idx2word is None:
            self.build_reverse_vocab()
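
to_word builds the reverse table lazily on first use (the hunk cuts off before the return, which presumably reads return self.idx2word[idx]); after that, index and word round-trip::

    vocab = Vocabulary()
    vocab.update("this is a word list".split())
    idx = vocab["word"]                   # 8, since indices 0-4 are reserved
    assert vocab.to_word(idx) == "word"   # first call triggers build_reverse_vocab
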
@@ -1,69 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
-
-import unittest
-import torch
-from fastNLP.data.field import TextField, LabelField
-from fastNLP.data.instance import Instance
-from fastNLP.data.dataset import DataSet
-from fastNLP.data.batch import Batch
-
-
-class TestField(unittest.TestCase):
-    def check_batched_data_equal(self, data1, data2):
-        self.assertEqual(len(data1), len(data2))
-        for i in range(len(data1)):
-            self.assertTrue(data1[i].keys(), data2[i].keys())
-        for i in range(len(data1)):
-            for t1, t2 in zip(data1[i].values(), data2[i].values()):
-                self.assertTrue(torch.equal(t1, t2))
-
-    def test_batchiter(self):
-        texts = [
-            "i am a cat",
-            "this is a test of new batch",
-            "haha"
-        ]
-        labels = [0, 1, 0]
-
-        # prepare vocabulary
-        vocab = {}
-        for text in texts:
-            for tokens in text.split():
-                if tokens not in vocab:
-                    vocab[tokens] = len(vocab)
-
-        # prepare input dataset
-        data = DataSet()
-        for text, label in zip(texts, labels):
-            x = TextField(text.split(), False)
-            y = LabelField(label, is_target=True)
-            ins = Instance(text=x, label=y)
-            data.append(ins)
-
-        # use vocabulary to index data
-        data.index_field("text", vocab)
-
-        # define naive sampler for batch class
-        class SeqSampler:
-            def __call__(self, dataset):
-                return list(range(len(dataset)))
-
-        # use bacth to iterate dataset
-        batcher = Batch(data, SeqSampler(), 2)
-        TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}]
-        TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}]
-        for epoch in range(3):
-            test_x, test_y = [], []
-            for batch_x, batch_y in batcher:
-                test_x.append(batch_x)
-                test_y.append(batch_y)
-            self.check_batched_data_equal(TRUE_X, test_x)
-            self.check_batched_data_equal(TRUE_Y, test_y)
-
-
-if __name__ == "__main__":
-    unittest.main()
@@ -1,9 +1,5 @@
-import os
-import sys
-sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
-
 import unittest
-from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX
+from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX
 
 class TestVocabulary(unittest.TestCase):
     def test_vocab(self):
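
The test now imports from the module's new home, fastNLP.core.vocabulary. A round-trip check along these lines would exercise it (illustrative only -- the actual body of test_vocab is truncated in this view)::

    import unittest
    from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX

    class TestVocabulary(unittest.TestCase):
        def test_vocab(self):
            vocab = Vocabulary()
            vocab.update("the quick brown fox".split())
            # reserved entries occupy indices 0-4, so new words start at 5
            self.assertEqual(vocab["the"], len(DEFAULT_WORD_TO_INDEX))
            self.assertEqual(vocab.to_word(vocab["fox"]), "fox")

    if __name__ == "__main__":
        unittest.main()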