Merge remote-tracking branch 'origin/master'

This commit is contained in:
choosewhatulike 2018-07-25 23:41:45 +08:00
commit febe27b5bb
3 changed files with 124 additions and 58 deletions

View File

@ -11,8 +11,24 @@ class DatasetLoader(BaseLoader):
class POSDatasetLoader(DatasetLoader): class POSDatasetLoader(DatasetLoader):
"""loader for pos data sets""" """Dataset Loader for POS Tag datasets.
In these datasets, each line are divided by '\t'
while the first Col is the vocabulary and the second
Col is the label.
Different sentence are divided by an empty line.
e.g:
Tom label1
and label2
Jerry label1
. label3
Hello label4
world label5
! label3
In this file, there are two sentence "Tom and Jerry ."
and "Hello world !". Each word has its own label from label1
to label5.
"""
def __init__(self, data_name, data_path): def __init__(self, data_name, data_path):
super(POSDatasetLoader, self).__init__(data_name, data_path) super(POSDatasetLoader, self).__init__(data_name, data_path)
@ -23,10 +39,42 @@ class POSDatasetLoader(DatasetLoader):
return line return line
def load_lines(self): def load_lines(self):
assert (os.path.exists(self.data_path)) """
:return data: three-level list
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
"""
with open(self.data_path, "r", encoding="utf-8") as f: with open(self.data_path, "r", encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()
return lines return self.parse(lines)
@staticmethod
def parse(lines):
data = []
sentence = []
for line in lines:
line = line.strip()
if len(line) > 1:
sentence.append(line.split('\t'))
else:
words = []
labels = []
for tokens in sentence:
words.append(tokens[0])
labels.append(tokens[1])
data.append([words, labels])
sentence = []
if len(sentence) != 0:
words = []
labels = []
for tokens in sentence:
words.append(tokens[0])
labels.append(tokens[1])
data.append([words, labels])
return data
class ClassDatasetLoader(DatasetLoader): class ClassDatasetLoader(DatasetLoader):
@ -112,3 +160,10 @@ class LMDatasetLoader(DatasetLoader):
with open(self.data_path, "r", encoding="utf=8") as f: with open(self.data_path, "r", encoding="utf=8") as f:
text = " ".join(f.readlines()) text = " ".join(f.readlines())
return text.strip().split() return text.strip().split()
if __name__ == "__main__":
    # Ad-hoc smoke test: load a sample POS dataset and print each
    # (word, label) pair. Relies on a data file from the project's test
    # fixtures, so this only runs when executed from the repo checkout.
    data = POSDatasetLoader("xxx", "../../test/data_for_tests/people.txt").load_lines()
    for example in data:
        # example is [word_list, label_list]; iterate the two in lockstep.
        for w, l in zip(example[0], example[1]):
            print(w, l)

View File

@ -28,33 +28,24 @@ class BasePreprocess(object):
class POSPreprocess(BasePreprocess): class POSPreprocess(BasePreprocess):
""" """
This class are used to preprocess the pos datasets. This class are used to preprocess the pos datasets.
In these datasets, each line are divided by '\t'
while the first Col is the vocabulary and the second
Col is the label.
Different sentence are divided by an empty line.
e.g:
Tom label1
and label2
Jerry label1
. label3
Hello label4
world label5
! label3
In this file, there are two sentence "Tom and Jerry ."
and "Hello world !". Each word has its own label from label1
to label5.
""" """
def __init__(self, data, pickle_path, train_dev_split=0): def __init__(self, data, pickle_path="./", train_dev_split=0):
""" """
Preprocess pipeline, including building mapping from words to index, from index to words, Preprocess pipeline, including building mapping from words to index, from index to words,
from labels/classes to index, from index to labels/classes. from labels/classes to index, from index to labels/classes.
:param data: :param data: three-level list
:param pickle_path: [
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
:param pickle_path: str, the directory to the pickle files. Default: "./"
:param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0. :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0.
To do: To do:
1. use @contextmanager to handle pickle dumps and loads 1. simplify __init__
""" """
super(POSPreprocess, self).__init__(data, pickle_path) super(POSPreprocess, self).__init__(data, pickle_path)
@ -75,6 +66,7 @@ class POSPreprocess(BasePreprocess):
else: else:
with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f: with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f:
_pickle.dump(self.label2index, f) _pickle.dump(self.label2index, f)
#something will be wrong if word2id.pkl is found but class2id.pkl is not found
if not self.pickle_exist("id2word.pkl"): if not self.pickle_exist("id2word.pkl"):
index2word = self.build_reverse_dict(self.word2index) index2word = self.build_reverse_dict(self.word2index)
@ -98,25 +90,23 @@ class POSPreprocess(BasePreprocess):
def build_dict(self, data): def build_dict(self, data):
""" """
Add new words with indices into self.word_dict, new labels with indices into self.label_dict. Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
:param data: list of list [word, label] :param data: three-level list
:return word2index: dict of (str, int) [
label2index: dict of (str, int) [ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
:return word2index: dict of {str, int}
label2index: dict of {str, int}
""" """
label2index = {} label2index = {}
word2index = DEFAULT_WORD_TO_INDEX word2index = DEFAULT_WORD_TO_INDEX
for line in data: for example in data:
line = line.strip() for word, label in zip(example[0], example[1]):
if len(line) <= 1: if word not in word2index:
continue word2index[word] = len(word2index)
tokens = line.split('\t') if label not in label2index:
label2index[label] = len(label2index)
if tokens[0] not in word2index:
# add (word, index) into the dict
word2index[tokens[0]] = len(word2index)
# for label in tokens[1: ]:
if tokens[1] not in label2index:
label2index[tokens[1]] = len(label2index)
return word2index, label2index return word2index, label2index
def pickle_exist(self, pickle_name): def pickle_exist(self, pickle_name):
@ -139,24 +129,31 @@ class POSPreprocess(BasePreprocess):
def to_index(self, data): def to_index(self, data):
""" """
Convert word strings and label strings into indices. Convert word strings and label strings into indices.
:param data: list of str. Each string is a line, described above. :param data: three-level list
:return data_index: list of tuple (word index, label index) [
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
:return data_index: the shape of data, but each string is replaced by its corresponding index
""" """
data_train = [] data_index = []
sentence = [] for example in data:
for w in data: word_list = []
w = w.strip() label_list = []
if len(w) <= 1: for word, label in zip(example[0], example[1]):
wid = [] word_list.append(self.word2index[word])
lid = [] label_list.append(self.label2index[label])
for i in range(len(sentence)): data_index.append([word_list, label_list])
wid.append(self.word2index[sentence[i][0]]) return data_index
lid.append(self.label2index[sentence[i][1]])
data_train.append((wid, lid)) @property
sentence = [] def vocab_size(self):
continue return len(self.word2index)
sentence.append(w.split('\t'))
return data_train @property
def num_classes(self):
return len(self.label2index)
class ClassPreprocess(BasePreprocess): class ClassPreprocess(BasePreprocess):

View File

@ -1,9 +1,23 @@
import unittest import unittest
from fastNLP.loader.dataset_loader import POSDatasetLoader
class MyTestCase(unittest.TestCase):
def test_something(self): class TestPreprocess(unittest.TestCase):
self.assertEqual(True, False) def test_case_1(self):
data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
["Hello", "world", "!"], ["T", "F", "F"]]
pickle_path = "./data_for_tests/"
# POSPreprocess(data, pickle_path)
class TestDatasetLoader(unittest.TestCase):
    """Unit tests for POSDatasetLoader."""

    def test_case_1(self):
        """parse() splits tab-separated lines into per-sentence word/label lists."""
        data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF"""
        expected = [
            [["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
            [["Hello", "world", "!"], ["T", "F", "F"]],
        ]
        parsed = POSDatasetLoader.parse(data.split("\n"))
        self.assertListEqual(parsed, expected, "POS Dataset Loader")
if __name__ == '__main__': if __name__ == '__main__':