[update] add Loader and Pipe for AG's News dataset

Yige Xu 2020-01-12 10:40:33 +08:00
parent 5f1d0cc4ee
commit 46ea42498d
3 changed files with 61 additions and 0 deletions

fastNLP/io/loader/classification.py

@@ -4,6 +4,7 @@ __all__ = [
    "YelpLoader",
    "YelpFullLoader",
    "YelpPolarityLoader",
    "AGsNewsLoader",
    "IMDBLoader",
    "SSTLoader",
    "SST2Loader",
@@ -161,6 +162,20 @@ class YelpPolarityLoader(YelpLoader):
        return data_dir


class AGsNewsLoader(YelpLoader):
    def download(self):
        """
        Automatically download the dataset. If you use this dataset, please cite the following paper:

        Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances
        in Neural Information Processing Systems 28 (NIPS 2015)

        :return: str, the directory where the dataset is stored
        """
        return self._get_dataset_path(dataset_name='ag-news')


class IMDBLoader(Loader):
    """
    The raw data is expected to contain one sample per line, with the target before the tab character and the text
    content after it.

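For reference, a minimal usage sketch of the new loader. This is an illustration only: it assumes AGsNewsLoader is re-exported from fastNLP.io like the Yelp loaders, and that download() returns the directory of the fetched 'ag-news' resource, as the docstring above states.

from fastNLP.io import AGsNewsLoader

loader = AGsNewsLoader()
# download() fetches the 'ag-news' resource (if not cached) and returns its directory
data_dir = loader.download()
data_bundle = loader.load(data_dir)  # DataBundle holding the train/test DataSets
print(data_bundle)                   # summary of the loaded datasets
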
fastNLP/io/loader/matching.py

@@ -57,6 +57,7 @@ class MNLILoader(Loader):
            f.readline()  # skip the header
            if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'):
                warnings.warn("RTE's test file has no target.")
                warnings.warn("MNLI's test file has no target.")
            for line in f:
                line = line.strip()
                if line:

fastNLP/io/pipe/classification.py

@@ -21,6 +21,8 @@ from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from ..data_bundle import DataBundle
from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader, \
    AGsNewsLoader
from ...core._logger import logger
from ...core.const import Const
from ...core.dataset import DataSet
@@ -272,6 +274,49 @@ class YelpPolarityPipe(_CLSPipe):
        return self.process(data_bundle=data_bundle)

class AGsNewsPipe(YelpFullPipe):
    """
    Processes data for AG's News. After processing, the DataSet contains the following fields:

    .. csv-table:: Fields of a DataSet processed by AGsNewsPipe
       :header: "raw_words", "target", "words", "seq_len"

       "I got 'new' tires from them and within...", 0, "[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160
       " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40
       "...", ., "[...]", .

    The input/target settings of each field, as printed by the dataset's print_field_meta(), are::

        +-------------+-----------+--------+-------+---------+
        | field_names | raw_words | target | words | seq_len |
        +-------------+-----------+--------+-------+---------+
        | is_input    | False     | False  | True  | True    |
        | is_target   | False     | True   | False | False   |
        | ignore_type |           | False  | False | False   |
        | pad_value   |           | 0      | 0     | 0       |
        +-------------+-----------+--------+-------+---------+

    """
    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
        """
        :param bool lower: whether to lowercase the input
        :param str tokenizer: which tokenizer to use to split the data into words; supports 'spacy' and 'raw'.
            'raw' splits on spaces
        """
        super().__init__(lower=lower, tokenizer=tokenizer)
        # AG's News labels its four classes (World, Sports, Business, Sci/Tech) as "1"-"4"
        # in the raw data; remap them to 0-based class ids.
        self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3}
    def process_from_file(self, paths=None):
        """
        :param str paths: the path to the data file(s); if None, the dataset is downloaded automatically
        :return: DataBundle
        """
        data_bundle = AGsNewsLoader().load(paths)
        return self.process(data_bundle=data_bundle)
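
A short usage sketch for the new pipe. This is a sketch only: it assumes AGsNewsPipe is exported from fastNLP.io, and that process_from_file() with the default paths=None downloads the data via AGsNewsLoader, mirroring the sibling Yelp pipes.

from fastNLP.io import AGsNewsPipe

pipe = AGsNewsPipe(lower=True, tokenizer='raw')  # 'raw' splits on whitespace
data_bundle = pipe.process_from_file()           # load + tokenize + index in one call
train_ds = data_bundle.get_dataset('train')
train_ds.print_field_meta()                      # prints the input/target table shown above
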
class SSTPipe(_CLSPipe):
    """
    After this Pipe is applied, the DataSet contains the fields shown below: