Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-02 20:27:35 +08:00)

Commit febe27b5bb: Merge remote-tracking branch 'origin/master'
@@ -11,8 +11,24 @@ class DatasetLoader(BaseLoader):
 
 
 class POSDatasetLoader(DatasetLoader):
-    """loader for pos data sets"""
+    """Dataset Loader for POS Tag datasets.
+
+    In these datasets, each line are divided by '\t'
+    while the first Col is the vocabulary and the second
+    Col is the label.
+    Different sentence are divided by an empty line.
+    e.g:
+        Tom label1
+        and label2
+        Jerry label1
+        . label3
+        Hello label4
+        world label5
+        ! label3
+    In this file, there are two sentence "Tom and Jerry ."
+    and "Hello world !". Each word has its own label from label1
+    to label5.
+    """
 
     def __init__(self, data_name, data_path):
         super(POSDatasetLoader, self).__init__(data_name, data_path)
@@ -23,10 +39,42 @@ class POSDatasetLoader(DatasetLoader):
         return line
 
     def load_lines(self):
         assert (os.path.exists(self.data_path))
+        """
+        :return data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
+                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
+                ...
+            ]
+        """
         with open(self.data_path, "r", encoding="utf-8") as f:
             lines = f.readlines()
-        return lines
+        return self.parse(lines)
+
+    @staticmethod
+    def parse(lines):
+        data = []
+        sentence = []
+        for line in lines:
+            line = line.strip()
+            if len(line) > 1:
+                sentence.append(line.split('\t'))
+            else:
+                words = []
+                labels = []
+                for tokens in sentence:
+                    words.append(tokens[0])
+                    labels.append(tokens[1])
+                data.append([words, labels])
+                sentence = []
+        if len(sentence) != 0:
+            words = []
+            labels = []
+            for tokens in sentence:
+                words.append(tokens[0])
+                labels.append(tokens[1])
+            data.append([words, labels])
+        return data
 
 
 class ClassDatasetLoader(DatasetLoader):
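For reference, here is a minimal sketch of what the new parse() helper returns for the example in the docstring above (the import path is taken from the updated test file further down; the expected output is written out by hand, not copied from the commit):

    from fastNLP.loader.dataset_loader import POSDatasetLoader

    lines = ["Tom\tlabel1", "and\tlabel2", "Jerry\tlabel1", ".\tlabel3", "",
             "Hello\tlabel4", "world\tlabel5", "!\tlabel3"]
    data = POSDatasetLoader.parse(lines)
    # data == [[["Tom", "and", "Jerry", "."], ["label1", "label2", "label1", "label3"]],
    #          [["Hello", "world", "!"], ["label4", "label5", "label3"]]]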
@@ -112,3 +160,10 @@ class LMDatasetLoader(DatasetLoader):
         with open(self.data_path, "r", encoding="utf=8") as f:
             text = " ".join(f.readlines())
         return text.strip().split()
+
+
+if __name__ == "__main__":
+    data = POSDatasetLoader("xxx", "../../test/data_for_tests/people.txt").load_lines()
+    for example in data:
+        for w, l in zip(example[0], example[1]):
+            print(w, l)
@@ -28,33 +28,24 @@ class BasePreprocess(object):
 
 class POSPreprocess(BasePreprocess):
     """
     This class are used to preprocess the pos datasets.
-    In these datasets, each line are divided by '\t'
-    while the first Col is the vocabulary and the second
-    Col is the label.
-    Different sentence are divided by an empty line.
-    e.g:
-        Tom label1
-        and label2
-        Jerry label1
-        . label3
-        Hello label4
-        world label5
-        ! label3
-    In this file, there are two sentence "Tom and Jerry ."
-    and "Hello world !". Each word has its own label from label1
-    to label5.
 
     """
 
-    def __init__(self, data, pickle_path, train_dev_split=0):
+    def __init__(self, data, pickle_path="./", train_dev_split=0):
         """
         Preprocess pipeline, including building mapping from words to index, from index to words,
         from labels/classes to index, from index to labels/classes.
-        :param data:
-        :param pickle_path:
+        :param data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
+                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
+                ...
+            ]
+        :param pickle_path: str, the directory to the pickle files. Default: "./"
+        :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0.
+
+        To do:
+            1. use @contextmanager to handle pickle dumps and loads
+            1. simplify __init__
         """
         super(POSPreprocess, self).__init__(data, pickle_path)
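A hedged usage sketch of the revised constructor: the import path and pickle directory below are assumptions (neither is shown in this diff), while vocab_size and num_classes come from the properties added further down.

    from fastNLP.loader.preprocess import POSPreprocess  # import path is an assumption

    data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
            [["Hello", "world", "!"], ["T", "F", "F"]]]
    p = POSPreprocess(data, pickle_path="./data_for_tests/", train_dev_split=0.1)
    print(p.vocab_size, p.num_classes)  # sizes of word2index and label2index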
@@ -75,6 +66,7 @@ class POSPreprocess(BasePreprocess):
         else:
             with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f:
                 _pickle.dump(self.label2index, f)
+        #something will be wrong if word2id.pkl is found but class2id.pkl is not found
 
         if not self.pickle_exist("id2word.pkl"):
             index2word = self.build_reverse_dict(self.word2index)
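The added comment flags a real gap. A hedged sketch of the kind of guard it hints at, meant to live inside __init__ (pickle_exist, pickle_path, word2index and label2index are used as elsewhere in this class; the combined check itself is a suggestion, not part of the commit):

    # Rewrite both mappings whenever either pickle is missing, so a stale
    # word2id.pkl without a matching class2id.pkl cannot slip through.
    if not (self.pickle_exist("word2id.pkl") and self.pickle_exist("class2id.pkl")):
        with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f:
            _pickle.dump(self.word2index, f)
        with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f:
            _pickle.dump(self.label2index, f)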
@@ -98,25 +90,23 @@ class POSPreprocess(BasePreprocess):
     def build_dict(self, data):
         """
         Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
-        :param data: list of list [word, label]
-        :return word2index: dict of (str, int)
-                label2index: dict of (str, int)
+        :param data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
+                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
+                ...
+            ]
+        :return word2index: dict of {str, int}
+                label2index: dict of {str, int}
         """
         label2index = {}
         word2index = DEFAULT_WORD_TO_INDEX
-        for line in data:
-            line = line.strip()
-            if len(line) <= 1:
-                continue
-            tokens = line.split('\t')
-
-            if tokens[0] not in word2index:
-                # add (word, index) into the dict
-                word2index[tokens[0]] = len(word2index)
-
-            # for label in tokens[1: ]:
-            if tokens[1] not in label2index:
-                label2index[tokens[1]] = len(label2index)
+        for example in data:
+            for word, label in zip(example[0], example[1]):
+                if word not in word2index:
+                    word2index[word] = len(word2index)
+                if label not in label2index:
+                    label2index[label] = len(label2index)
         return word2index, label2index
 
     def pickle_exist(self, pickle_name):
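A standalone sketch of the indexing behaviour the rewritten build_dict loop implements; the reserved entries below stand in for DEFAULT_WORD_TO_INDEX, whose contents are not shown in this diff:

    word2index = {"<pad>": 0, "<unk>": 1}  # stand-in for DEFAULT_WORD_TO_INDEX (assumed reserved tokens)
    label2index = {}
    data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
            [["Hello", "world", "!"], ["T", "F", "F"]]]
    for example in data:
        for word, label in zip(example[0], example[1]):
            if word not in word2index:
                word2index[word] = len(word2index)
            if label not in label2index:
                label2index[label] = len(label2index)
    # word2index: {"<pad>": 0, "<unk>": 1, "Tom": 2, "and": 3, "Jerry": 4, ".": 5, "Hello": 6, "world": 7, "!": 8}
    # label2index: {"T": 0, "F": 1}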
@@ -139,24 +129,31 @@ class POSPreprocess(BasePreprocess):
     def to_index(self, data):
         """
         Convert word strings and label strings into indices.
-        :param data: list of str. Each string is a line, described above.
-        :return data_index: list of tuple (word index, label index)
+        :param data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
+                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
+                ...
+            ]
+        :return data_index: the shape of data, but each string is replaced by its corresponding index
         """
-        data_train = []
-        sentence = []
-        for w in data:
-            w = w.strip()
-            if len(w) <= 1:
-                wid = []
-                lid = []
-                for i in range(len(sentence)):
-                    wid.append(self.word2index[sentence[i][0]])
-                    lid.append(self.label2index[sentence[i][1]])
-                data_train.append((wid, lid))
-                sentence = []
-                continue
-            sentence.append(w.split('\t'))
-        return data_train
+        data_index = []
+        for example in data:
+            word_list = []
+            label_list = []
+            for word, label in zip(example[0], example[1]):
+                word_list.append(self.word2index[word])
+                label_list.append(self.label2index[label])
+            data_index.append([word_list, label_list])
+        return data_index
+
+    @property
+    def vocab_size(self):
+        return len(self.word2index)
+
+    @property
+    def num_classes(self):
+        return len(self.label2index)
 
 
 class ClassPreprocess(BasePreprocess):
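A matching sketch of the shape-preserving conversion the new to_index docstring describes, reusing data, word2index and label2index from the build_dict sketch above (the index values are illustrative, not taken from the commit):

    data_index = []
    for example in data:
        word_list = [word2index[w] for w in example[0]]
        label_list = [label2index[l] for l in example[1]]
        data_index.append([word_list, label_list])
    # data_index == [[[2, 3, 4, 5], [0, 1, 0, 1]], [[6, 7, 8], [0, 1, 1]]]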
@@ -1,9 +1,23 @@
 import unittest
 
+from fastNLP.loader.dataset_loader import POSDatasetLoader
+
-class MyTestCase(unittest.TestCase):
-    def test_something(self):
-        self.assertEqual(True, False)
-
+class TestPreprocess(unittest.TestCase):
+    def test_case_1(self):
+        data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
+                ["Hello", "world", "!"], ["T", "F", "F"]]
+        pickle_path = "./data_for_tests/"
+        # POSPreprocess(data, pickle_path)
+
+
+class TestDatasetLoader(unittest.TestCase):
+    def test_case_1(self):
+        data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF"""
+        lines = data.split("\n")
+        answer = POSDatasetLoader.parse(lines)
+        truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]]
+        self.assertListEqual(answer, truth, "POS Dataset Loader")
 
 
 if __name__ == '__main__':
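If the whole test module is present on disk, the updated tests can be run with the standard unittest runner; the module path below is an assumption, since the test file's name is not shown in this diff:

    python -m unittest test.test_dataset_loader -v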