Merge pull request #100 from choosewhatulike/dev

add dataset read functions
Xipeng Qiu 2018-10-29 23:51:26 +08:00 committed by GitHub
commit 8b2ae2db8b
4 changed files with 43 additions and 2 deletions

fastNLP/core/dataset.py

@@ -7,6 +7,8 @@ from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
 from fastNLP.core.vocabulary import Vocabulary
 
+_READERS = {}
+
 
 class DataSet(list):
     """A DataSet object is a list of Instance objects.
@@ -135,3 +137,24 @@ class DataSet(list):
         self.origin_len = (origin_field + "_origin_len", origin_field) \
             if origin_len_name is None else (origin_len_name, origin_field)
         return self
+
+    def __getattribute__(self, name):
+        if name in _READERS:
+            # add read_*data() support
+            def _read(*args, **kwargs):
+                data = _READERS[name]().load(*args, **kwargs)
+                self.extend(data)
+                return self
+            return _read
+        else:
+            return object.__getattribute__(self, name)
+
+    @classmethod
+    def set_reader(cls, method_name):
+        """decorator to add dataloader support
+        """
+        assert isinstance(method_name, str)
+        def wrapper(read_cls):
+            _READERS[method_name] = read_cls
+            return read_cls
+        return wrapper
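The two hunks above work together: set_reader() records each decorated loader class in the module-level _READERS dict, and __getattribute__ turns any registered name into an on-the-fly method that instantiates the loader, loads the data, and extends the DataSet in place. A minimal usage sketch (the file path is hypothetical; read_pos is registered in dataset_loader.py further down):

    from fastNLP.core.dataset import DataSet

    # read_pos is not defined on DataSet itself; __getattribute__ finds it in
    # _READERS, instantiates POSDataSetLoader, and extends this DataSet in place.
    ds = DataSet().read_pos("path/to/pos_data.txt")  # hypothetical path
    print(len(ds))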

fastNLP/core/vocabulary.py

@@ -70,6 +70,7 @@ class Vocabulary(object):
         else:
             self.word_count[word] += 1
         self.word2idx = None
+        return self
 
     def build_vocab(self):
         """build 'word to index' dict, and filter the word using `max_size` and `min_freq`

fastNLP/loader/dataset_loader.py

@@ -88,6 +88,7 @@ class DataSetLoader(BaseLoader):
         raise NotImplementedError
 
 
+@DataSet.set_reader('read_raw')
 class RawDataSetLoader(DataSetLoader):
     def __init__(self):
         super(RawDataSetLoader, self).__init__()

@@ -103,6 +104,7 @@ class RawDataSetLoader(DataSetLoader):
         return convert_seq_dataset(data)
 
 
+@DataSet.set_reader('read_pos')
 class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.

@@ -172,6 +174,7 @@ class POSDataSetLoader(DataSetLoader):
         return convert_seq2seq_dataset(data)
 
 
+@DataSet.set_reader('read_tokenize')
 class TokenizeDataSetLoader(DataSetLoader):
     """
     Data set loader for tokenization data sets

@@ -231,6 +234,7 @@ class TokenizeDataSetLoader(DataSetLoader):
         return convert_seq2seq_dataset(data)
 
 
+@DataSet.set_reader('read_class')
 class ClassDataSetLoader(DataSetLoader):
     """Loader for classification data sets"""

@@ -269,6 +273,7 @@ class ClassDataSetLoader(DataSetLoader):
         return convert_seq2tag_dataset(data)
 
 
+@DataSet.set_reader('read_conll')
 class ConllLoader(DataSetLoader):
     """loader for conll format files"""

@@ -310,6 +315,7 @@ class ConllLoader(DataSetLoader):
         pass
 
 
+@DataSet.set_reader('read_lm')
 class LMDataSetLoader(DataSetLoader):
     """Language Model Dataset Loader

@@ -346,6 +352,7 @@ class LMDataSetLoader(DataSetLoader):
         pass
 
 
+@DataSet.set_reader('read_people_daily')
 class PeopleDailyCorpusLoader(DataSetLoader):
     """
     People Daily Corpus: Chinese word segmentation, POS tag, NER
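With these decorators applied, a DataSet instance now exposes read_raw, read_pos, read_tokenize, read_class, read_conll, read_lm, and read_people_daily, each forwarding its arguments to the corresponding loader's load(). A hedged sketch (the path is hypothetical):

    from fastNLP.core.dataset import DataSet

    # Dispatches through _READERS to ClassDataSetLoader.load()
    ds = DataSet().read_class("path/to/classification_data.txt")  # hypothetical path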

test/loader/test_dataset_loader.py

@@ -3,7 +3,7 @@ import unittest
 from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
     PeopleDailyCorpusLoader, ConllLoader
+from fastNLP.core.dataset import DataSet
 
 
 class TestDatasetLoader(unittest.TestCase):
 
     def test_case_1(self):

@@ -15,13 +15,23 @@ class TestDatasetLoader(unittest.TestCase):
     def test_case_TokenizeDatasetLoader(self):
         loader = TokenizeDataSetLoader()
-        data = loader.load("./test/data_for_tests/cws_pku_utf_8", max_seq_len=32)
+        filepath = "./test/data_for_tests/cws_pku_utf_8"
+        data = loader.load(filepath, max_seq_len=32)
         assert len(data) > 0
+
+        data1 = DataSet()
+        data1.read_tokenize(filepath, max_seq_len=32)
+        assert len(data1) > 0
         print("pass TokenizeDataSetLoader test!")
 
     def test_case_POSDatasetLoader(self):
         loader = POSDataSetLoader()
+        filepath = "./test/data_for_tests/people.txt"
         data = loader.load("./test/data_for_tests/people.txt")
         datas = loader.load_lines("./test/data_for_tests/people.txt")
+
+        data1 = DataSet().read_pos(filepath)
+        assert len(data1) > 0
         print("pass POSDataSetLoader test!")
 
     def test_case_LMDatasetLoader(self):
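The new assertions can be exercised by loading the test class directly; the module path test.loader.test_dataset_loader is an assumption based on the imports above:

    import unittest

    # Load and run TestDatasetLoader (module path is an assumption)
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "test.loader.test_dataset_loader.TestDatasetLoader")
    unittest.TextTestRunner(verbosity=2).run(suite)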