Merge remote-tracking branch 'origin/master'

choosewhatulike 2018-07-25 23:41:45 +08:00
commit febe27b5bb
3 changed files with 124 additions and 58 deletions

View File

@@ -11,8 +11,24 @@ class DatasetLoader(BaseLoader):
class POSDatasetLoader(DatasetLoader):
    """loader for pos data sets"""
    """Dataset Loader for POS Tag datasets.

    In these datasets, each line is split by '\t':
    the first column is the word and the second
    column is its label.
    Sentences are separated by an empty line.
    e.g.:

        Tom label1
        and label2
        Jerry label1
        . label3

        Hello label4
        world label5
        ! label3

    This file contains two sentences, "Tom and Jerry ."
    and "Hello world !". Each word has its own label, from label1
    to label5.
    """

    def __init__(self, data_name, data_path):
        super(POSDatasetLoader, self).__init__(data_name, data_path)
@@ -23,10 +39,42 @@ class POSDatasetLoader(DatasetLoader):
            return line

    def load_lines(self):
        assert (os.path.exists(self.data_path))
        """
        :return data: three-level list
            [
                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
                ...
            ]
        """
        with open(self.data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        return lines
        return self.parse(lines)
    @staticmethod
    def parse(lines):
        data = []
        sentence = []
        for line in lines:
            line = line.strip()
            if len(line) > 1:
                sentence.append(line.split('\t'))
            else:
                words = []
                labels = []
                for tokens in sentence:
                    words.append(tokens[0])
                    labels.append(tokens[1])
                data.append([words, labels])
                sentence = []
        if len(sentence) != 0:
            words = []
            labels = []
            for tokens in sentence:
                words.append(tokens[0])
                labels.append(tokens[1])
            data.append([words, labels])
        return data
class ClassDatasetLoader(DatasetLoader):
@@ -112,3 +160,10 @@ class LMDatasetLoader(DatasetLoader):
        with open(self.data_path, "r", encoding="utf-8") as f:
            text = " ".join(f.readlines())
        return text.strip().split()
if __name__ == "__main__":
    data = POSDatasetLoader("xxx", "../../test/data_for_tests/people.txt").load_lines()
    for example in data:
        for w, l in zip(example[0], example[1]):
            print(w, l)
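
For reference, parse() alone is enough to reproduce what the new test at the bottom of this commit checks: tab-separated lines in, three-level list out. A minimal sketch using only names from the diff above:

    from fastNLP.loader.dataset_loader import POSDatasetLoader

    lines = ["Tom\tT", "and\tF", "Jerry\tT", ".\tF", "",
             "Hello\tT", "world\tF", "!\tF"]
    print(POSDatasetLoader.parse(lines))
    # [[['Tom', 'and', 'Jerry', '.'], ['T', 'F', 'T', 'F']],
    #  [['Hello', 'world', '!'], ['T', 'F', 'F']]]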

View File

@@ -28,33 +28,24 @@ class BasePreprocess(object):
class POSPreprocess(BasePreprocess):
    """
    This class is used to preprocess POS datasets.
    In these datasets, each line is split by '\t':
    the first column is the word and the second
    column is its label.
    Sentences are separated by an empty line.
    e.g.:

        Tom label1
        and label2
        Jerry label1
        . label3

        Hello label4
        world label5
        ! label3

    This file contains two sentences, "Tom and Jerry ."
    and "Hello world !". Each word has its own label, from label1
    to label5.
    """
    def __init__(self, data, pickle_path, train_dev_split=0):
    def __init__(self, data, pickle_path="./", train_dev_split=0):
"""
Preprocess pipeline, including building mapping from words to index, from index to words,
from labels/classes to index, from index to labels/classes.
:param data:
:param pickle_path:
:param data: three-level list
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
:param pickle_path: str, the directory to the pickle files. Default: "./"
:param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0.
To do:
1. use @contextmanager to handle pickle dumps and loads
1. simplify __init__
"""
super(POSPreprocess, self).__init__(data, pickle_path)
@@ -75,6 +66,7 @@ class POSPreprocess(BasePreprocess):
        else:
            with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f:
                _pickle.dump(self.label2index, f)
        # something will go wrong if word2id.pkl is found but class2id.pkl is not

        if not self.pickle_exist("id2word.pkl"):
            index2word = self.build_reverse_dict(self.word2index)
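
The comment above flags a genuine pitfall: each pickle is checked independently, so a directory holding word2id.pkl but not class2id.pkl leaves the mappings half rebuilt and out of sync. A minimal sketch of an all-or-nothing guard (a hypothetical helper, not part of this commit; the file names come from the calls above):

    import os

    def all_pickles_exist(pickle_path,
                          names=("word2id.pkl", "class2id.pkl", "id2word.pkl")):
        # Hypothetical helper: treat the cache as valid only when every pickle
        # is present, so a partial cache is rebuilt as a whole.
        return all(os.path.exists(os.path.join(pickle_path, n)) for n in names)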
@@ -98,25 +90,23 @@ class POSPreprocess(BasePreprocess):
    def build_dict(self, data):
        """
        Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
        :param data: list of list [word, label]
        :return word2index: dict of (str, int)
            label2index: dict of (str, int)
        :param data: three-level list
            [
                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
                ...
            ]
        :return word2index: dict of {str: int}
            label2index: dict of {str: int}
        """
        label2index = {}
        word2index = DEFAULT_WORD_TO_INDEX
        for line in data:
            line = line.strip()
            if len(line) <= 1:
                continue
            tokens = line.split('\t')
            if tokens[0] not in word2index:
                # add (word, index) into the dict
                word2index[tokens[0]] = len(word2index)
            # for label in tokens[1: ]:
            if tokens[1] not in label2index:
                label2index[tokens[1]] = len(label2index)
        for example in data:
            for word, label in zip(example[0], example[1]):
                if word not in word2index:
                    word2index[word] = len(word2index)
                if label not in label2index:
                    label2index[label] = len(label2index)
        return word2index, label2index

    def pickle_exist(self, pickle_name):
@@ -139,24 +129,31 @@ class POSPreprocess(BasePreprocess):
    def to_index(self, data):
        """
        Convert word strings and label strings into indices.
        :param data: list of str. Each string is a line, as described above.
        :return data_index: list of tuples (word index, label index)
        :param data: three-level list
            [
                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
                ...
            ]
        :return data_index: same shape as data, with each string replaced by its corresponding index
        """
        data_train = []
        sentence = []
        for w in data:
            w = w.strip()
            if len(w) <= 1:
                wid = []
                lid = []
                for i in range(len(sentence)):
                    wid.append(self.word2index[sentence[i][0]])
                    lid.append(self.label2index[sentence[i][1]])
                data_train.append((wid, lid))
                sentence = []
                continue
            sentence.append(w.split('\t'))
        return data_train
        data_index = []
        for example in data:
            word_list = []
            label_list = []
            for word, label in zip(example[0], example[1]):
                word_list.append(self.word2index[word])
                label_list.append(self.label2index[label])
            data_index.append([word_list, label_list])
        return data_index
    @property
    def vocab_size(self):
        return len(self.word2index)

    @property
    def num_classes(self):
        return len(self.label2index)
class ClassPreprocess(BasePreprocess):
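
Taken together, build_dict and to_index give the preprocess step its shape: the first pass assigns each unseen word and label the next free index, the second pass replaces every string with its index. A minimal sketch of that round trip (plain dicts stand in for the instance state; the real word2index starts from DEFAULT_WORD_TO_INDEX, whose reserved entries are defined elsewhere in fastNLP):

    data = [[["Hello", "world", "!"], ["T", "F", "F"]]]
    word2index, label2index = {}, {}
    for example in data:
        for word, label in zip(example[0], example[1]):
            word2index.setdefault(word, len(word2index))
            label2index.setdefault(label, len(label2index))
    # word2index == {'Hello': 0, 'world': 1, '!': 2}
    # label2index == {'T': 0, 'F': 1}
    # to_index then yields [[[0, 1, 2], [0, 1, 1]]]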

View File

@@ -1,9 +1,23 @@
import unittest

from fastNLP.loader.dataset_loader import POSDatasetLoader


class MyTestCase(unittest.TestCase):
    def test_something(self):
        self.assertEqual(True, False)


class TestPreprocess(unittest.TestCase):
    def test_case_1(self):
        data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
                [["Hello", "world", "!"], ["T", "F", "F"]]]
        pickle_path = "./data_for_tests/"
        # POSPreprocess(data, pickle_path)


class TestDatasetLoader(unittest.TestCase):
    def test_case_1(self):
        data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF"""
        lines = data.split("\n")
        answer = POSDatasetLoader.parse(lines)
        truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
                 [["Hello", "world", "!"], ["T", "F", "F"]]]
        self.assertListEqual(answer, truth, "POS Dataset Loader")


if __name__ == '__main__':
    unittest.main()