mirror of https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00

Reorganize all dataset loaders and add unit tests

This commit is contained in:
parent bfaf09df8c
commit 986541139a
@@ -11,18 +11,24 @@ class BaseLoader(object):

    @staticmethod
    def load_lines(data_path):
        """Read the file line by line, strip the whitespace on both sides of each line,
        and return a list of str.
        """
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.readlines()
        return [line.strip() for line in text]

    @classmethod
    def load(cls, data_path):
        """Read the file line by line, strip the whitespace on both sides of each line,
        then split each line into its characters. Return a list of list of str.
        """
        with open(data_path, "r", encoding="utf-8") as f:
            text = f.readlines()
        return [[word for word in sent.strip()] for sent in text]

    @classmethod
    def load_with_cache(cls, data_path, cache_path):
        """A cached version of load.
        """
        if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path):
            with open(cache_path, 'rb') as f:
                return pickle.load(f)
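The diff view truncates this hunk, so the cache-miss branch of load_with_cache is not shown. A minimal sketch of the mtime-based pattern the visible lines imply (a standalone helper written for illustration, not the repository's actual continuation): when the cache file is missing or older than the data file, reload and re-pickle.

import os
import pickle


def load_with_cache(load_fn, data_path, cache_path):
    """Return the pickled cache if it is newer than the data file; otherwise
    reload with ``load_fn`` and refresh the cache."""
    if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    data = load_fn(data_path)
    with open(cache_path, "wb") as f:
        pickle.dump(data, f)
    return data

Because the check compares modification times, touching the raw data file automatically invalidates the cache.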
@@ -11,7 +11,6 @@ class ConfigLoader(BaseLoader):

    :param str data_path: path to the config

    """

    def __init__(self, data_path=None):
        super(ConfigLoader, self).__init__()
        if data_path is not None:
@@ -30,7 +29,7 @@ class ConfigLoader(BaseLoader):
        Example::

            test_args = ConfigSection()
-           ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
+           ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})

        """
        assert isinstance(sections, dict)
@@ -202,8 +201,6 @@ class ConfigSaver(object):
                    continue

                if '=' not in line:
-                   # log = create_logger(__name__, './config_saver.log')
-                   # log.error("can NOT load config file [%s]" % self.file_path)
                    raise RuntimeError("can NOT load config file {}".format(self.file_path))

                key = line.split('=', maxsplit=1)[0].strip()
@@ -263,10 +260,6 @@ class ConfigSaver(object):
                        change_file = True
                        break
                    if section_file[k] != section[k]:
-                       # logger = create_logger(__name__, "./config_loader.log")
-                       # logger.warning("section [%s] in config file [%s] has been changed" % (
-                       #     section_name, self.file_path
-                       # ))
                        change_file = True
                        break
            if not change_file:
@@ -126,8 +126,8 @@ class RawDataSetLoader(DataSetLoader):
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')


-class POSDataSetLoader(DataSetLoader):
-    """Dataset Loader for a POS Tag dataset.
+class DummyPOSReader(DataSetLoader):
+    """A simple reader for a dummy POS tagging dataset.

    In these datasets, each line is divided by "\t": the first column is the token and the second
    column is the label. Different sentences are divided by an empty line.
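For readers unfamiliar with that layout, a minimal sketch of parsing such a tab-separated, blank-line-delimited file (a hypothetical helper for illustration, not the loader's actual implementation):

def parse_pos_file(path):
    """Yield (words, labels) pairs from a file where each non-empty line is
    "token<TAB>label" and sentences are separated by blank lines."""
    words, labels = [], []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # a blank line ends the current sentence
                if words:
                    yield words, labels
                    words, labels = [], []
                continue
            token, label = line.split("\t")
            words.append(token)
            labels.append(label)
    if words:  # flush the last sentence
        yield words, labels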
@@ -146,7 +146,7 @@ class POSDataSetLoader(DataSetLoader):
    """

    def __init__(self):
-        super(POSDataSetLoader, self).__init__()
+        super(DummyPOSReader, self).__init__()

    def load(self, data_path):
        """
@@ -194,16 +194,14 @@ class POSDataSetLoader(DataSetLoader):
        return convert_seq2seq_dataset(data)


-DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')
+DataLoaderRegister.set_reader(DummyPOSReader, 'read_pos')


-class TokenizeDataSetLoader(DataSetLoader):
-    """
-    Data set loader for tokenization data sets
-    """
+class DummyCWSReader(DataSetLoader):
+    """Load pku dataset for Chinese word segmentation.
+    """

    def __init__(self):
-        super(TokenizeDataSetLoader, self).__init__()
+        super(DummyCWSReader, self).__init__()

    def load(self, data_path, max_seq_len=32):
        """Load pku dataset for Chinese word segmentation.
@@ -256,11 +254,11 @@ class TokenizeDataSetLoader(DataSetLoader):
        return convert_seq2seq_dataset(data)


-class ClassDataSetLoader(DataSetLoader):
+class DummyClassificationReader(DataSetLoader):
    """Loader for a dummy classification data set"""

    def __init__(self):
-        super(ClassDataSetLoader, self).__init__()
+        super(DummyClassificationReader, self).__init__()

    def load(self, data_path):
        assert os.path.exists(data_path)
@@ -271,7 +269,7 @@ class ClassDataSetLoader(DataSetLoader):

    @staticmethod
    def parse(lines):
-        """
+        """The first token of each line is the label; the remaining tokens are the characters/words, separated by spaces.

        :param lines: lines from dataset
        :return: list(list(list())): the three levels of lists are words, sentence, and dataset
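As described above, each line carries the label first and the tokens after it. A minimal illustrative sketch of that parse (a hypothetical helper, not the class's actual code):

def parse_classification_lines(lines):
    """Turn lines of "label w1 w2 ..." into a list of (words, label) pairs."""
    dataset = []
    for line in lines:
        parts = line.strip().split()
        if not parts:
            continue
        label, words = parts[0], parts[1:]
        dataset.append((words, label))
    return dataset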
@@ -327,16 +325,11 @@ class ConllLoader(DataSetLoader):
        pass


-class LMDataSetLoader(DataSetLoader):
-    """Language Model Dataset Loader
-
-    This loader produces data for language model training in a supervised way.
-    That means it has X and Y.
-
+class DummyLMReader(DataSetLoader):
+    """A Dummy Language Model Dataset Reader
    """

    def __init__(self):
-        super(LMDataSetLoader, self).__init__()
+        super(DummyLMReader, self).__init__()

    def load(self, data_path):
        if not os.path.exists(data_path):
@@ -364,19 +357,25 @@ class LMDataSetLoader(DataSetLoader):


class PeopleDailyCorpusLoader(DataSetLoader):
-    """
-    People Daily Corpus: Chinese word segmentation, POS tag, NER
-    """
+    """The People's Daily corpus dataset.
+    """

    def __init__(self):
        super(PeopleDailyCorpusLoader, self).__init__()
+        self.pos = True
+        self.ner = True

-    def load(self, data_path):
+    def load(self, data_path, pos=True, ner=True):
+        """
+
+        :param str data_path: path to the data file
+        :param bool pos: whether to include POS tags
+        :param bool ner: whether to include named-entity tags
+        :return: a DataSet object
+        """
+        self.pos, self.ner = pos, ner
        with open(data_path, "r", encoding="utf-8") as f:
            sents = f.readlines()

-        pos_tag_examples = []
-        ner_examples = []
+        examples = []
        for sent in sents:
            if len(sent) <= 2:
                continue
@@ -410,40 +409,44 @@ class PeopleDailyCorpusLoader(DataSetLoader):
                    sent_ner.append(ner_tag)
                    sent_pos_tag.append(pos)
                    sent_words.append(token)
-            pos_tag_examples.append([sent_words, sent_pos_tag])
-            ner_examples.append([sent_words, sent_ner])
-            # List[List[List[str], List[str]]]
-        # ner_examples not used
-        return self.convert(pos_tag_examples)
+            example = [sent_words]
+            if self.pos is True:
+                example.append(sent_pos_tag)
+            if self.ner is True:
+                example.append(sent_ner)
+            examples.append(example)
+        return self.convert(examples)

    def convert(self, data):
        data_set = DataSet()
        for item in data:
-            sent_words, sent_pos_tag = item[0], item[1]
-            data_set.append(Instance(words=sent_words, tags=sent_pos_tag))
-        data_set.apply(lambda ins: len(ins), new_field_name="seq_len")
-        data_set.set_target("tags")
-        data_set.set_input("sent_words")
-        data_set.set_input("seq_len")
+            sent_words = item[0]
+            if self.pos is True and self.ner is True:
+                instance = Instance(words=sent_words, pos_tags=item[1], ner=item[2])
+            elif self.pos is True:
+                instance = Instance(words=sent_words, pos_tags=item[1])
+            elif self.ner is True:
+                instance = Instance(words=sent_words, ner=item[1])
+            else:
+                instance = Instance(words=sent_words)
+            data_set.append(instance)
+        data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
        return data_set
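A short usage sketch of the new pos/ner flags, reusing the data path from the new unit test below (the printed output naturally depends on that file):

from fastNLP.io.dataset_loader import PeopleDailyCorpusLoader

loader = PeopleDailyCorpusLoader()
# pos/ner control which extra fields each Instance carries besides "words"
ds = loader.load("test/data_for_tests/people_daily_raw.txt", pos=True, ner=False)
print(len(ds), ds[0])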

class Conll2003Loader(DataSetLoader):
-    """Self-defined loader of conll2003 dataset
+    """Loader for conll2003 dataset

    More information about the given dataset could be found on
    https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data

    """

    def __init__(self):
        super(Conll2003Loader, self).__init__()

    def load(self, dataset_path):
        with open(dataset_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        # Parse the dataset line by line
        parsed_data = []
        sentence = []
        tokens = []
@@ -470,21 +473,20 @@ class Conll2003Loader(DataSetLoader):
                lambda labels: labels[1], sample[1]))
            label2_list = list(map(
                lambda labels: labels[2], sample[1]))
-            dataset.append(Instance(token_list=sample[0],
-                                    label0_list=label0_list,
-                                    label1_list=label1_list,
-                                    label2_list=label2_list))
+            dataset.append(Instance(tokens=sample[0],
+                                    pos=label0_list,
+                                    chucks=label1_list,
+                                    ner=label2_list))

        return dataset


-class SNLIDataSetLoader(DataSetLoader):
+class SNLIDataSetReader(DataSetLoader):
    """A data set loader for SNLI data set.

    """

    def __init__(self):
-        super(SNLIDataSetLoader, self).__init__()
+        super(SNLIDataSetReader, self).__init__()

    def load(self, path_list):
        """
@@ -553,6 +555,8 @@ class ConllCWSReader(object):
        """
        The returned DataSet contains only the field raw_sentence, whose value is a str.
        The input is assumed to be in conll format: two sentences are separated by a blank
        line and each line has 7 columns, i.e.
        ::

            1 编者按 编者按 NN O 11 nmod:topic
            2 : : PU O 11 punct
            3 7月 7月 NT DATE 4 compound:nn
@@ -564,6 +568,7 @@ class ConllCWSReader(object):
            3 飞行 飞行 NN O 8 nsubj
            4 从 从 P O 5 case
            5 外型 外型 NN O 8 nmod:prep

        """
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
@@ -575,7 +580,7 @@ class ConllCWSReader(object):
                elif line.startswith('#'):
                    continue
                else:
-                    sample.append(line.split('\t'))
+                    sample.append(line.strip().split())
            if len(sample) > 0:
                datalist.append(sample)
@@ -592,7 +597,6 @@ class ConllCWSReader(object):
                sents = [line]
            for raw_sentence in sents:
                ds.append(Instance(raw_sentence=raw_sentence))

        return ds

    def get_char_lst(self, sample):
@@ -607,70 +611,22 @@ class ConllCWSReader(object):
        return text


-class POSCWSReader(DataSetLoader):
-    """
-    Supports input where every line holds one word and a blank line marks the boundary between two sentences, e.g.
-        迈 N
-        向 N
-        充 N
-        ...
-        泽 I-PER
-        民 I-PER
-
-        ( N
-        一 N
-        九 N
-        ...
-
-
-    :param filepath:
-    :return:
-    """
-
-    def __init__(self, in_word_splitter=None):
-        super().__init__()
-        self.in_word_splitter = in_word_splitter
-
-    def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
-        if in_word_splitter is None:
-            in_word_splitter = self.in_word_splitter
-        dataset = DataSet()
-        with open(filepath, 'r') as f:
-            words = []
-            for line in f:
-                line = line.strip()
-                if len(line) == 0:  # new line
-                    if len(words) == 0:  # empty sentences are not accepted
-                        continue
-                    line = ' '.join(words)
-                    if cut_long_sent:
-                        sents = cut_long_sentence(line)
-                    else:
-                        sents = [line]
-                    for sent in sents:
-                        instance = Instance(raw_sentence=sent)
-                        dataset.append(instance)
-                    words = []
-                else:
-                    line = line.split()[0]
-                    if in_word_splitter is None:
-                        words.append(line)
-                    else:
-                        words.append(line.split(in_word_splitter)[0])
-        return dataset
-
-
class NaiveCWSReader(DataSetLoader):
    """
    This reader assumes the word segmentation dataset is already segmented by spaces.
    For example::

        这是 fastNLP , 一个 非常 good 的 包 .

    Or, each part is additionally followed by a POS tag.
    For example::

        也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY

    """

    def __init__(self, in_word_splitter=None):
-        super().__init__()
+        super(NaiveCWSReader, self).__init__()
        self.in_word_splitter = in_word_splitter

    def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
@@ -680,8 +636,10 @@ class NaiveCWSReader(DataSetLoader):
            and
            也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY
        If the splitter is not None, the second form is assumed, and each part is split by the
        splitter with only the first piece kept, e.g. "也/D".split('/')[0].

        :param filepath:
        :param in_word_splitter:
        :param cut_long_sent:
        :return:
        """
        if in_word_splitter == None:
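A tiny concrete check of the splitter behaviour described in that docstring:

line = "也/D 在/P 團員/Na 之中/Ng"
words = [part.split('/')[0] for part in line.split()]
assert words == ['也', '在', '團員', '之中']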
@@ -740,7 +698,9 @@ def cut_long_sentence(sent, max_sample_length=200):


class ZhConllPOSReader(object):
-    # 中文colln格式reader
+    """Reads the Chinese Conll format; returns character-level tags, expanding the original word-level tags with BMES notation.
+
+    """
    def __init__(self):
        pass
@@ -750,6 +710,8 @@ class ZhConllPOSReader(object):
            words: list of str,
            tag: list of str, with BMES tags added; e.g. an original word-level sequence
                ['VP', 'NN', 'NN', ..] becomes ["S-VP", "B-NN", "M-NN", ..]
        The input is assumed to be in conll format: two sentences are separated by a blank
        line and each line has 7 columns, i.e.
        ::

            1 编者按 编者按 NN O 11 nmod:topic
            2 : : PU O 11 punct
            3 7月 7月 NT DATE 4 compound:nn
@@ -761,6 +723,7 @@ class ZhConllPOSReader(object):
            3 飞行 飞行 NN O 8 nsubj
            4 从 从 P O 5 case
            5 外型 外型 NN O 8 nmod:prep

        """
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
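The BMES expansion mentioned above turns one word-level tag into one tag per character. A minimal sketch of that conversion (an illustrative helper, not the reader's exact code, though the deleted ConllPOSReader below applies essentially the same B/M/E/S logic):

def to_bmes(words, tags):
    """Expand word-level tags to character-level B/M/E/S tags.

    >>> to_bmes(["编者按", ":"], ["NN", "PU"])
    (['编', '者', '按', ':'], ['B-NN', 'M-NN', 'E-NN', 'S-PU'])
    """
    chars, char_tags = [], []
    for word, tag in zip(words, tags):
        if len(word) == 1:
            char_tags.append('S-{}'.format(tag))
        else:
            char_tags.append('B-{}'.format(tag))
            char_tags.extend('M-{}'.format(tag) for _ in range(len(word) - 2))
            char_tags.append('E-{}'.format(tag))
        chars.extend(list(word))
    return chars, char_tags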
@@ -815,67 +778,10 @@ class ZhConllPOSReader(object):
        return text, pos_tags


-class ConllPOSReader(object):
-    # The returned Dataset has two fields: words (list of list, where the inner list holds
-    # characters) and tag (list of str, tags annotated with BIO).
-    def __init__(self):
-        pass
-
-    def load(self, path):
-        datalist = []
-        with open(path, 'r', encoding='utf-8') as f:
-            sample = []
-            for line in f:
-                if line.startswith('\n'):
-                    datalist.append(sample)
-                    sample = []
-                elif line.startswith('#'):
-                    continue
-                else:
-                    sample.append(line.split('\t'))
-            if len(sample) > 0:
-                datalist.append(sample)
-
-        ds = DataSet()
-        for sample in datalist:
-            # print(sample)
-            res = self.get_one(sample)
-            if res is None:
-                continue
-            char_seq = []
-            pos_seq = []
-            for word, tag in zip(res[0], res[1]):
-                if len(word) == 1:
-                    char_seq.append(word)
-                    pos_seq.append('S-{}'.format(tag))
-                elif len(word) > 1:
-                    pos_seq.append('B-{}'.format(tag))
-                    for _ in range(len(word) - 2):
-                        pos_seq.append('M-{}'.format(tag))
-                    pos_seq.append('E-{}'.format(tag))
-                    char_seq.extend(list(word))
-                else:
-                    raise ValueError("Zero length of word detected.")
-
-            ds.append(Instance(words=char_seq,
-                               tag=pos_seq))
-        return ds
-
-    def get_one(self, sample):
-        if len(sample) == 0:
-            return None
-        text = []
-        pos_tags = []
-        for w in sample:
-            t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
-            if t3 == '_':
-                return None
-            text.append(t1)
-            pos_tags.append(t2)
-        return text, pos_tags
-
-
class ConllxDataLoader(object):
+    """Returns word-level label information: the word, its POS tag, its (syntactic) head, and
+    the (syntactic) arc label. Completely different from ``ZhConllPOSReader``.
+
+    """
    def load(self, path):
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
@@ -1,18 +1,20 @@
import copy
+from collections import defaultdict

import numpy as np
import torch
-from collections import defaultdict
from torch import nn
from torch.nn import functional as F
-from fastNLP.modules.utils import initial_parameter
-from fastNLP.modules.encoder.variational_rnn import VarLSTM
-from fastNLP.modules.encoder.transformer import TransformerEncoder
-from fastNLP.modules.dropout import TimestepDropout
-from fastNLP.models.base_model import BaseModel
-from fastNLP.modules.utils import seq_mask

+from fastNLP.core.losses import LossFunc
+from fastNLP.core.metrics import MetricBase
+from fastNLP.core.utils import seq_lens_to_masks
+from fastNLP.models.base_model import BaseModel
+from fastNLP.modules.dropout import TimestepDropout
+from fastNLP.modules.encoder.transformer import TransformerEncoder
+from fastNLP.modules.encoder.variational_rnn import VarLSTM
+from fastNLP.modules.utils import initial_parameter
+from fastNLP.modules.utils import seq_mask


def mst(scores):
    """
@@ -4,7 +4,7 @@ from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.core.utils import ClassPreprocess as Preprocess
from fastNLP.io.config_io import ConfigLoader
from fastNLP.io.config_io import ConfigSection
-from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader
+from fastNLP.io.dataset_loader import DummyClassificationReader as Dataset_loader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.aggregator.self_attention import SelfAttention
from fastNLP.modules.decoder.MLP import MLP
@@ -138,6 +138,7 @@ class TestCase1(unittest.TestCase):
            for batch_x, batch_y in batch:
                time.sleep(pause_seconds)

    """
    def test_multi_workers_batch(self):
        batch_size = 32
        pause_seconds = 0.01
@@ -154,7 +155,8 @@ class TestCase1(unittest.TestCase):
            end1 = time.time()
            for batch_x, batch_y in batch:
                time.sleep(pause_seconds)
    """

    """
    def test_pin_memory(self):
        batch_size = 32
        pause_seconds = 0.01
@@ -172,3 +174,4 @@ class TestCase1(unittest.TestCase):
        # OOM happens here
        # for batch_x, batch_y in batch:
        #     time.sleep(pause_seconds)
    """
@@ -237,6 +237,7 @@ class TrainerTestGround(unittest.TestCase):
                          use_tqdm=False,
                          print_every=2)

    """
    def test_trainer_multiprocess(self):
        dataset = prepare_fake_dataset2('x1', 'x2')
        dataset.set_input('x1', 'x2', 'y', flag=True)
@@ -264,4 +265,4 @@ class TrainerTestGround(unittest.TestCase):
            timeout=0,
        )
        trainer.train()
    """
@@ -1,24 +1,27 @@
import unittest

-from fastNLP.io.dataset_loader import Conll2003Loader
+from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, ConllCWSReader, \
+    ZhConllPOSReader, ConllxDataLoader


class TestDatasetLoader(unittest.TestCase):

-    def test_case_1(self):
-        '''
-        Test the the loader of Conll2003 dataset
-        '''
+    def test_Conll2003Loader(self):
+        """
+        Test the loader of the Conll2003 dataset
+        """
        dataset_path = "test/data_for_tests/conll_2003_example.txt"
        loader = Conll2003Loader()
        dataset_2003 = loader.load(dataset_path)

-        for item in dataset_2003:
-            len0 = len(item["label0_list"])
-            len1 = len(item["label1_list"])
-            len2 = len(item["label2_list"])
-            lentoken = len(item["token_list"])
-            self.assertNotEqual(len0, 0)
-            self.assertEqual(len0, len1)
-            self.assertEqual(len1, len2)
+    def test_PeopleDailyCorpusLoader(self):
+        data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt")
+
+    def test_ConllCWSReader(self):
+        dataset = ConllCWSReader().load("test/data_for_tests/conll_example.txt")
+
+    def test_ZhConllPOSReader(self):
+        dataset = ZhConllPOSReader().load("test/data_for_tests/zh_sample.conllx")
+
+    def test_ConllxDataLoader(self):
+        dataset = ConllxDataLoader().load("test/data_for_tests/zh_sample.conllx")
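The old length assertions no longer apply once the Conll2003 fields are renamed. If one wanted to keep that alignment check, a hedged sketch of a method that could be added to TestDatasetLoader, using the field names this commit introduces (tokens, pos, chucks, ner; note the field is spelled "chucks" here, presumably meaning chunk tags):

    def test_Conll2003Loader_fields(self):
        dataset_2003 = Conll2003Loader().load("test/data_for_tests/conll_2003_example.txt")
        for item in dataset_2003:
            # every tag sequence should be non-empty and aligned with the tokens
            self.assertNotEqual(len(item["tokens"]), 0)
            self.assertEqual(len(item["tokens"]), len(item["pos"]))
            self.assertEqual(len(item["pos"]), len(item["chucks"]))
            self.assertEqual(len(item["chucks"]), len(item["ner"]))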