mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 21:28:01 +08:00
[update] add Loader and Pipe for AG's News dataset
This commit is contained in:
parent
5f1d0cc4ee
commit
46ea42498d
@ -4,6 +4,7 @@ __all__ = [
|
||||
"YelpLoader",
|
||||
"YelpFullLoader",
|
||||
"YelpPolarityLoader",
|
||||
"AGsNewsLoader",
|
||||
"IMDBLoader",
|
||||
"SSTLoader",
|
||||
"SST2Loader",
|
||||
@ -161,6 +162,20 @@ class YelpPolarityLoader(YelpLoader):
|
||||
return data_dir
|
||||
|
||||
|
||||
class AGsNewsLoader(YelpLoader):
    """Loader for the AG's News topic-classification dataset."""

    def download(self):
        """
        Automatically download the dataset. If you use this dataset,
        please cite:

        Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional
        Networks for Text Classification. Advances in Neural Information
        Processing Systems 28 (NIPS 2015)

        :return: str, path of the directory holding the dataset
        """
        # Delegate to the shared helper using the registered dataset key.
        data_dir = self._get_dataset_path(dataset_name='ag-news')
        return data_dir
|
||||
|
||||
|
||||
class IMDBLoader(Loader):
|
||||
"""
|
||||
原始数据中内容应该为, 每一行为一个sample,制表符之前为target,制表符之后为文本内容。
|
||||
|
@ -57,6 +57,7 @@ class MNLILoader(Loader):
|
||||
f.readline() # 跳过header
|
||||
if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'):
|
||||
warnings.warn("RTE's test file has no target.")
|
||||
warnings.warn("MNLI's test file has no target.")
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
|
@ -21,6 +21,8 @@ from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_insta
|
||||
from ..data_bundle import DataBundle
|
||||
from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader
|
||||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
|
||||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader, \
|
||||
AGsNewsLoader
|
||||
from ...core._logger import logger
|
||||
from ...core.const import Const
|
||||
from ...core.dataset import DataSet
|
||||
@ -272,6 +274,49 @@ class YelpPolarityPipe(_CLSPipe):
|
||||
return self.process(data_bundle=data_bundle)
|
||||
|
||||
|
||||
class AGsNewsPipe(YelpFullPipe):
    """
    Pipe for the AG's News text-classification data. After processing,
    the DataSet holds the fields shown below.

    .. csv-table:: Fields of a DataSet processed by AGsNewsPipe
       :header: "raw_words", "target", "words", "seq_len"

       "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160
       " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40
       "...", ., "[...]", .

    NOTE(review): the sample rows above appear copied from the Yelp pipe
    documentation — confirm against real AG's News output.

    The input/target status of each field, as printed by the dataset's
    print_field_meta() function, is::

        +-------------+-----------+--------+-------+---------+
        | field_names | raw_words | target | words | seq_len |
        +-------------+-----------+--------+-------+---------+
        |   is_input  |   False   | False  |  True |   True  |
        |  is_target  |   False   |  True  | False |  False  |
        | ignore_type |           | False  | False |  False  |
        |  pad_value  |           |   0    |   0   |    0    |
        +-------------+-----------+--------+-------+---------+

    """

    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
        """
        :param bool lower: whether to lowercase the input text.
        :param str tokenizer: tokenization strategy; supports 'spacy' and
            'raw' ('raw' splits on whitespace).
        """
        super().__init__(lower=lower, tokenizer=tokenizer)
        # AG's News labels its four classes "1"-"4"; remap them to
        # 0-based integer targets.
        self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3}

    def process_from_file(self, paths=None):
        """
        Load the raw data with AGsNewsLoader, then run the standard
        processing pipeline over it.

        :param str paths: location of the raw files, forwarded to
            AGsNewsLoader().load().
        :return: DataBundle
        """
        bundle = AGsNewsLoader().load(paths)
        return self.process(data_bundle=bundle)
|
||||
|
||||
|
||||
class SSTPipe(_CLSPipe):
|
||||
"""
|
||||
经过该Pipe之后,DataSet中具备的field如下所示
|
||||
|
Loading…
Reference in New Issue
Block a user