mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-02 20:27:35 +08:00)

Merge remote-tracking branch 'origin/master'

Commit febe27b5bb
@@ -11,8 +11,24 @@ class DatasetLoader(BaseLoader):
 
 
 class POSDatasetLoader(DatasetLoader):
-    """loader for pos data sets"""
+    """Dataset Loader for POS tag datasets.
+    In these datasets, each line is split by '\t':
+    the first column is the word and the second
+    column is its label.
+    Sentences are separated by an empty line.
+    e.g.:
+        Tom     label1
+        and     label2
+        Jerry   label1
+        .       label3
+        Hello   label4
+        world   label5
+        !       label3
+    In this file, there are two sentences, "Tom and Jerry ."
+    and "Hello world !". Each word has its own label,
+    from label1 to label5.
+    """
 
     def __init__(self, data_name, data_path):
         super(POSDatasetLoader, self).__init__(data_name, data_path)
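
To make the format concrete, a toy file in this layout can be produced as follows (a minimal sketch; the file name and label names are invented for illustration):

    # Write the two-sentence example from the docstring to disk.
    # Columns are separated by a tab; sentences by a blank line.
    sample = ("Tom\tlabel1\nand\tlabel2\nJerry\tlabel1\n.\tlabel3\n"
              "\n"
              "Hello\tlabel4\nworld\tlabel5\n!\tlabel3\n")
    with open("toy_pos.txt", "w", encoding="utf-8") as f:
        f.write(sample)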
@@ -23,10 +39,42 @@ class POSDatasetLoader(DatasetLoader):
         return line
 
     def load_lines(self):
-        assert (os.path.exists(self.data_path))
+        """
+        :return data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_11, label_12, ...] ],
+                [ [word_21, word_22, ...], [label_21, label_22, ...] ],
+                ...
+            ]
+        """
         with open(self.data_path, "r", encoding="utf-8") as f:
             lines = f.readlines()
-        return lines
+        return self.parse(lines)
+
+    @staticmethod
+    def parse(lines):
+        data = []
+        sentence = []
+        for line in lines:
+            line = line.strip()
+            if len(line) > 1:
+                sentence.append(line.split('\t'))
+            else:
+                words = []
+                labels = []
+                for tokens in sentence:
+                    words.append(tokens[0])
+                    labels.append(tokens[1])
+                data.append([words, labels])
+                sentence = []
+        if len(sentence) != 0:
+            words = []
+            labels = []
+            for tokens in sentence:
+                words.append(tokens[0])
+                labels.append(tokens[1])
+            data.append([words, labels])
+        return data
 
 
 class ClassDatasetLoader(DatasetLoader):
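
As a quick sanity check of the new parse helper, here is what it yields on a few made-up lines (a sketch that mirrors the unit test added later in this commit):

    from fastNLP.loader.dataset_loader import POSDatasetLoader

    lines = ["Tom\tT", "and\tF", "Jerry\tT", ".\tF", "",
             "Hello\tT", "world\tF", "!\tF"]
    data = POSDatasetLoader.parse(lines)
    # Two sentences, each as [words, labels]:
    # [[['Tom', 'and', 'Jerry', '.'], ['T', 'F', 'T', 'F']],
    #  [['Hello', 'world', '!'], ['T', 'F', 'F']]]
    print(data)

Note that parse treats any stripped line of length 1 or less as a sentence boundary, and the trailing `if len(sentence) != 0` block flushes the last sentence when the input does not end with a blank line.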
@@ -112,3 +160,10 @@ class LMDatasetLoader(DatasetLoader):
         with open(self.data_path, "r", encoding="utf=8") as f:
             text = " ".join(f.readlines())
         return text.strip().split()
+
+
+if __name__ == "__main__":
+    data = POSDatasetLoader("xxx", "../../test/data_for_tests/people.txt").load_lines()
+    for example in data:
+        for w, l in zip(example[0], example[1]):
+            print(w, l)
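
The new __main__ block is a manual smoke test rather than a unit test: it loads people.txt and prints every (word, label) pair. Separately, the unchanged context line above opens the file with encoding="utf=8", presumably a typo for "utf-8" (CPython's codec-name normalization is lenient about separator characters, so it still resolves to UTF-8); this commit leaves it as is.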
@@ -28,33 +28,24 @@ class BasePreprocess(object):
 class POSPreprocess(BasePreprocess):
     """
     This class is used to preprocess POS datasets.
-    In these datasets, each line are divided by '\t'
-    while the first Col is the vocabulary and the second
-    Col is the label.
-    Different sentence are divided by an empty line.
-    e.g:
-    Tom     label1
-    and     label2
-    Jerry   label1
-    .       label3
-    Hello   label4
-    world   label5
-    !       label3
-    In this file, there are two sentence "Tom and Jerry ."
-    and "Hello world !". Each word has its own label from label1
-    to label5.
     """
 
-    def __init__(self, data, pickle_path, train_dev_split=0):
+    def __init__(self, data, pickle_path="./", train_dev_split=0):
         """
         Preprocess pipeline, including building mapping from words to index, from index to words,
         from labels/classes to index, from index to labels/classes.
-        :param data:
-        :param pickle_path:
+        :param data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_11, label_12, ...] ],
+                [ [word_21, word_22, ...], [label_21, label_22, ...] ],
+                ...
+            ]
+        :param pickle_path: str, the directory to the pickle files. Default: "./"
         :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0.
 
         To do:
-        1. use @contextmanager to handle pickle dumps and loads
+        1. simplify __init__
         """
         super(POSPreprocess, self).__init__(data, pickle_path)
 
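
For reference, a minimal usage sketch of the revised constructor; the import path below is a guess, since this diff does not show which module POSPreprocess lives in:

    # Hypothetical import path; adjust to the actual module.
    from fastNLP.loader.preprocess import POSPreprocess

    data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
            [["Hello", "world", "!"], ["T", "F", "F"]]]
    # pickle_path now defaults to "./"; a train_dev_split of 0.2 would
    # hold out 20% of the training data as a dev set.
    p = POSPreprocess(data, pickle_path="./pickles/", train_dev_split=0)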
@@ -75,6 +66,7 @@ class POSPreprocess(BasePreprocess):
         else:
             with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f:
                 _pickle.dump(self.label2index, f)
+        # something will go wrong if word2id.pkl is found but class2id.pkl is not found
 
         if not self.pickle_exist("id2word.pkl"):
             index2word = self.build_reverse_dict(self.word2index)
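
One way to rule out the inconsistency the new comment warns about is to check all required pickles up front instead of file by file (a sketch only, not the committed fix):

    import os

    def all_pickles_exist(pickle_path, names=("word2id.pkl", "class2id.pkl")):
        """Return True only if every required pickle file is present."""
        return all(os.path.exists(os.path.join(pickle_path, n)) for n in names)

    # Rebuild both dictionaries unless both files are found, so that
    # word2id.pkl and class2id.pkl can never get out of sync.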
@@ -98,25 +90,23 @@ class POSPreprocess(BasePreprocess):
     def build_dict(self, data):
         """
         Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
-        :param data: list of list [word, label]
-        :return word2index: dict of (str, int)
-                label2index: dict of (str, int)
+        :param data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_11, label_12, ...] ],
+                [ [word_21, word_22, ...], [label_21, label_22, ...] ],
+                ...
+            ]
+        :return word2index: dict of {str: int}
+                label2index: dict of {str: int}
         """
         label2index = {}
         word2index = DEFAULT_WORD_TO_INDEX
-        for line in data:
-            line = line.strip()
-            if len(line) <= 1:
-                continue
-            tokens = line.split('\t')
-
-            if tokens[0] not in word2index:
-                # add (word, index) into the dict
-                word2index[tokens[0]] = len(word2index)
-
-            # for label in tokens[1: ]:
-            if tokens[1] not in label2index:
-                label2index[tokens[1]] = len(label2index)
+        for example in data:
+            for word, label in zip(example[0], example[1]):
+                if word not in word2index:
+                    word2index[word] = len(word2index)
+                if label not in label2index:
+                    label2index[label] = len(label2index)
         return word2index, label2index
 
     def pickle_exist(self, pickle_name):
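
A standalone sketch of the indexing scheme build_dict now implements. The contents of DEFAULT_WORD_TO_INDEX are an assumption here (reserved ids for padding and unknown words); the real constant may differ:

    data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]]]
    word2index = {"<pad>": 0, "<unk>": 1}  # stand-in for DEFAULT_WORD_TO_INDEX
    label2index = {}
    for example in data:
        for word, label in zip(example[0], example[1]):
            if word not in word2index:
                word2index[word] = len(word2index)
            if label not in label2index:
                label2index[label] = len(label2index)
    print(word2index)   # {'<pad>': 0, '<unk>': 1, 'Tom': 2, 'and': 3, 'Jerry': 4, '.': 5}
    print(label2index)  # {'T': 0, 'F': 1}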
@@ -139,24 +129,31 @@ class POSPreprocess(BasePreprocess):
     def to_index(self, data):
         """
         Convert word strings and label strings into indices.
-        :param data: list of str. Each string is a line, described above.
-        :return data_index: list of tuple (word index, label index)
+        :param data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_11, label_12, ...] ],
+                [ [word_21, word_22, ...], [label_21, label_22, ...] ],
+                ...
+            ]
+        :return data_index: the same shape as data, but with each string replaced by its corresponding index
         """
-        data_train = []
-        sentence = []
-        for w in data:
-            w = w.strip()
-            if len(w) <= 1:
-                wid = []
-                lid = []
-                for i in range(len(sentence)):
-                    wid.append(self.word2index[sentence[i][0]])
-                    lid.append(self.label2index[sentence[i][1]])
-                data_train.append((wid, lid))
-                sentence = []
-                continue
-            sentence.append(w.split('\t'))
-        return data_train
+        data_index = []
+        for example in data:
+            word_list = []
+            label_list = []
+            for word, label in zip(example[0], example[1]):
+                word_list.append(self.word2index[word])
+                label_list.append(self.label2index[label])
+            data_index.append([word_list, label_list])
+        return data_index
+
+    @property
+    def vocab_size(self):
+        return len(self.word2index)
+
+    @property
+    def num_classes(self):
+        return len(self.label2index)
 
 
 class ClassPreprocess(BasePreprocess):
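
Continuing that sketch, to_index maps each string through the dictionaries while preserving the nesting, and the new properties simply report the dictionary sizes:

    word2index = {'<pad>': 0, '<unk>': 1, 'Tom': 2, 'and': 3, 'Jerry': 4, '.': 5}
    label2index = {'T': 0, 'F': 1}
    data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]]]

    data_index = [[[word2index[w] for w in ex[0]],
                   [label2index[l] for l in ex[1]]] for ex in data]
    print(data_index)        # [[[2, 3, 4, 5], [0, 1, 0, 1]]]
    print(len(word2index))   # what vocab_size returns -> 6
    print(len(label2index))  # what num_classes returns -> 2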
@@ -1,9 +1,23 @@
 import unittest
 
+from fastNLP.loader.dataset_loader import POSDatasetLoader
 
-class MyTestCase(unittest.TestCase):
-    def test_something(self):
-        self.assertEqual(True, False)
+
+class TestPreprocess(unittest.TestCase):
+    def test_case_1(self):
+        data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
+                [["Hello", "world", "!"], ["T", "F", "F"]]]
+        pickle_path = "./data_for_tests/"
+        # POSPreprocess(data, pickle_path)
+
+
+class TestDatasetLoader(unittest.TestCase):
+    def test_case_1(self):
+        data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF"""
+        lines = data.split("\n")
+        answer = POSDatasetLoader.parse(lines)
+        truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]]
+        self.assertListEqual(answer, truth, "POS Dataset Loader")
 
 
 if __name__ == '__main__':
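
One subtlety in TestDatasetLoader.test_case_1: the triple-quoted string is not a raw string, so \t and \n are real tab and newline characters, and splitting on "\n" reproduces exactly the line structure load_lines would read from disk. A quick check (a sketch):

    data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF"""
    lines = data.split("\n")
    print(lines[:3])   # ['Tom\tT', 'and\tF', 'Jerry\tT'] -- with real tabs inside
    print(len(lines))  # 8 entries; the empty string marks the sentence break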