mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 20:57:37 +08:00
add dataset read functions
This commit is contained in:
parent
6a1d237c64
commit
ebbfcb7829
@ -7,6 +7,8 @@ from fastNLP.core.field import TextField, LabelField
|
||||
from fastNLP.core.instance import Instance
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
|
||||
_READERS = {}
|
||||
|
||||
class DataSet(list):
|
||||
"""A DataSet object is a list of Instance objects.
|
||||
|
||||
@ -125,3 +127,24 @@ class DataSet(list):
|
||||
self.origin_len = (origin_field + "_origin_len", origin_field) \
|
||||
if origin_len_name is None else (origin_len_name, origin_field)
|
||||
return self
|
||||
|
||||
def __getattribute__(self, name):
|
||||
if name in _READERS:
|
||||
# add read_*data() support
|
||||
def _read(*args, **kwargs):
|
||||
data = _READERS[name]().load(*args, **kwargs)
|
||||
self.extend(data)
|
||||
return self
|
||||
return _read
|
||||
else:
|
||||
return object.__getattribute__(self, name)
|
||||
|
||||
@classmethod
|
||||
def set_reader(cls, method_name):
|
||||
"""decorator to add dataloader support
|
||||
"""
|
||||
assert isinstance(method_name, str)
|
||||
def wrapper(read_cls):
|
||||
_READERS[method_name] = read_cls
|
||||
return read_cls
|
||||
return wrapper
|
||||
|
@ -69,6 +69,7 @@ class Vocabulary(object):
|
||||
else:
|
||||
self.word_count[word] += 1
|
||||
self.word2idx = None
|
||||
return self
|
||||
|
||||
|
||||
def build_vocab(self):
|
||||
|
@ -84,6 +84,7 @@ class DataSetLoader(BaseLoader):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DataSet.set_reader('read_raw')
|
||||
class RawDataSetLoader(DataSetLoader):
|
||||
def __init__(self):
|
||||
super(RawDataSetLoader, self).__init__()
|
||||
@ -98,6 +99,7 @@ class RawDataSetLoader(DataSetLoader):
|
||||
def convert(self, data):
|
||||
return convert_seq_dataset(data)
|
||||
|
||||
@DataSet.set_reader('read_pos')
|
||||
class POSDataSetLoader(DataSetLoader):
|
||||
"""Dataset Loader for POS Tag datasets.
|
||||
|
||||
@ -166,6 +168,7 @@ class POSDataSetLoader(DataSetLoader):
|
||||
"""
|
||||
return convert_seq2seq_dataset(data)
|
||||
|
||||
@DataSet.set_reader('read_tokenize')
|
||||
class TokenizeDataSetLoader(DataSetLoader):
|
||||
"""
|
||||
Data set loader for tokenization data sets
|
||||
@ -224,7 +227,7 @@ class TokenizeDataSetLoader(DataSetLoader):
|
||||
def convert(self, data):
|
||||
return convert_seq2seq_dataset(data)
|
||||
|
||||
|
||||
@DataSet.set_reader('read_class')
|
||||
class ClassDataSetLoader(DataSetLoader):
|
||||
"""Loader for classification data sets"""
|
||||
|
||||
@ -262,7 +265,7 @@ class ClassDataSetLoader(DataSetLoader):
|
||||
def convert(self, data):
|
||||
return convert_seq2tag_dataset(data)
|
||||
|
||||
|
||||
@DataSet.set_reader('read_conll')
|
||||
class ConllLoader(DataSetLoader):
|
||||
"""loader for conll format files"""
|
||||
|
||||
@ -303,7 +306,7 @@ class ConllLoader(DataSetLoader):
|
||||
def convert(self, data):
|
||||
pass
|
||||
|
||||
|
||||
@DataSet.set_reader('read_lm')
|
||||
class LMDataSetLoader(DataSetLoader):
|
||||
"""Language Model Dataset Loader
|
||||
|
||||
@ -339,6 +342,7 @@ class LMDataSetLoader(DataSetLoader):
|
||||
def convert(self, data):
|
||||
pass
|
||||
|
||||
@DataSet.set_reader('read_people_daily')
|
||||
class PeopleDailyCorpusLoader(DataSetLoader):
|
||||
"""
|
||||
People Daily Corpus: Chinese word segmentation, POS tag, NER
|
||||
|
@ -3,7 +3,7 @@ import unittest
|
||||
|
||||
from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
|
||||
PeopleDailyCorpusLoader, ConllLoader
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
|
||||
class TestDatasetLoader(unittest.TestCase):
|
||||
def test_case_1(self):
|
||||
@ -15,13 +15,23 @@ class TestDatasetLoader(unittest.TestCase):
|
||||
|
||||
def test_case_TokenizeDatasetLoader(self):
|
||||
loader = TokenizeDataSetLoader()
|
||||
data = loader.load("./test/data_for_tests/cws_pku_utf_8", max_seq_len=32)
|
||||
filepath = "./test/data_for_tests/cws_pku_utf_8"
|
||||
data = loader.load(filepath, max_seq_len=32)
|
||||
assert len(data) > 0
|
||||
|
||||
data1 = DataSet()
|
||||
data1.read_tokenize(filepath, max_seq_len=32)
|
||||
assert len(data1) > 0
|
||||
print("pass TokenizeDataSetLoader test!")
|
||||
|
||||
def test_case_POSDatasetLoader(self):
|
||||
loader = POSDataSetLoader()
|
||||
filepath = "./test/data_for_tests/people.txt"
|
||||
data = loader.load("./test/data_for_tests/people.txt")
|
||||
datas = loader.load_lines("./test/data_for_tests/people.txt")
|
||||
|
||||
data1 = DataSet().read_pos(filepath)
|
||||
assert len(data1) > 0
|
||||
print("pass POSDataSetLoader test!")
|
||||
|
||||
def test_case_LMDatasetLoader(self):
|
||||
|
Loading…
Reference in New Issue
Block a user