[verify] sst2loader
This commit is contained in:
parent 9130929145
commit ba47fb851a
@@ -1,191 +0,0 @@
import csv
from typing import Iterable, Union, Dict

from nltk import Tree

from fastNLP import DataSet, Instance
from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
from fastNLP.io.base_loader import DataInfo, DataSetLoader
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
from reproduction.utils import check_dataloader_paths, get_tokenizer


class SSTLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`

    Reads the SST dataset. The resulting DataSet contains the fields::

        words: list(str), the text to be classified
        target: str, the label of the text

    Data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip

    :param subtree: whether to expand every subtree into its own example, enlarging the dataset. Default: ``False``
    :param fine_grained: whether to use the SST-5 label set; if ``False``, SST-2 labels are used. Default: ``False``
    """

    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
    DATA_DIR = 'sst/'

    def __init__(self, subtree=False, fine_grained=False):
        super(SSTLoader, self).__init__()
        self.subtree = subtree

        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
                 '3': 'positive', '4': 'very positive'}
        if not fine_grained:
            # collapse the five-way labels to the coarse-grained setting
            tag_v['0'] = tag_v['1']
            tag_v['4'] = tag_v['3']
        self.tag_v = tag_v

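    # Note derived from the mapping above: with fine_grained=False the possible
    # targets collapse to {'negative', 'neutral', 'positive'}, e.g. a raw tree
    # label '0' maps to 'negative' and '4' maps to 'positive'.
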
    def _load(self, path):
        """
        :param str path: path of the file holding the data
        :return: a :class:`~fastNLP.DataSet` object
        """
        with open(path, 'r', encoding='utf-8') as f:
            datas = []
            for l in f:
                datas.extend([(s, self.tag_v[t])
                              for s, t in self._get_one(l, self.subtree)])
        ds = DataSet()
        for words, tag in datas:
            ds.append(Instance(words=words, target=tag))
        return ds

    @staticmethod
    def _get_one(data, subtree):
        tree = Tree.fromstring(data)
        if subtree:
            return [(t.leaves(), t.label()) for t in tree.subtrees()]
        return [(tree.leaves(), tree.label())]

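    # Illustrative example (input string made up for this note, not taken from
    # the dataset): given the PTB-style line "(3 (2 It) (3 (2 's) (3 good)))",
    #   _get_one(line, subtree=False) -> [(['It', "'s", 'good'], '3')]
    # and with subtree=True every subtree contributes its own (leaves, label) pair.
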
    def process(self,
                paths,
                train_ds: Iterable[str] = None,
                src_vocab_op: VocabularyOption = None,
                tgt_vocab_op: VocabularyOption = None,
                src_embed_op: EmbeddingOption = None):
        input_name, target_name = 'words', 'target'
        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

        info = DataInfo(datasets=self.load(paths))
        _train_ds = [info.datasets[name]
                     for name in train_ds] if train_ds else info.datasets.values()
        src_vocab.from_dataset(*_train_ds, field_name=input_name)
        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
        src_vocab.index_dataset(
            *info.datasets.values(),
            field_name=input_name, new_field_name=input_name)
        tgt_vocab.index_dataset(
            *info.datasets.values(),
            field_name=target_name, new_field_name=target_name)
        info.vocabs = {
            input_name: src_vocab,
            target_name: tgt_vocab
        }

        if src_embed_op is not None:
            src_embed_op.vocab = src_vocab
            init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
            info.embeddings[input_name] = init_emb

        for name, dataset in info.datasets.items():
            dataset.set_input(input_name)
            dataset.set_target(target_name)

        return info
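

# A minimal usage sketch for SSTLoader, assuming DataSetLoader.load accepts a
# dict mapping split names to file paths (as process() implies); the paths are
# placeholders, not from the original file:
#
#   loader = SSTLoader(subtree=False, fine_grained=False)
#   data_info = loader.process({'train': 'sst/train.txt', 'dev': 'sst/dev.txt'})
#   print(len(data_info.vocabs['words']), len(data_info.vocabs['target']))

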
class sst2Loader(DataSetLoader):
    """
    Data source "SST": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8'
    """

    def __init__(self):
        super(sst2Loader, self).__init__()
        self.tokenizer = get_tokenizer()

    def _load(self, path: str) -> DataSet:
        ds = DataSet()
        all_count = 0
        skip_row = 0
        with open(path, encoding='utf-8') as f:
            csv_reader = csv.reader(f, delimiter='\t')
            for idx, row in enumerate(csv_reader):
                if idx <= skip_row:
                    continue
                target = row[1]
                words = self.tokenizer(row[0])
                ds.append(Instance(words=words, target=target))
                all_count += 1
        print("all count:", all_count)
        return ds

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                src_embed_opt: EmbeddingOption = None,
                char_level_op=False):

        paths = check_dataloader_paths(paths)
        datasets = {}
        info = DataInfo()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        def wordtochar(words):
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

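        # Example derived from the code above: wordtochar(['New', 'York'])
        # returns ['n', 'e', 'w', '', 'y', 'o', 'r', 'k']; the empty string
        # marks a word boundary and the trailing one is popped.
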
        input_name, target_name = 'words', 'target'
        info.vocabs = {}

        # optionally split the word sequence into a character sequence
        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')

        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name='words')
        src_vocab.index_dataset(*datasets.values(), field_name='words')

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name='target')
        tgt_vocab.index_dataset(*datasets.values(), field_name='target')

        info.vocabs = {
            "words": src_vocab,
            "target": tgt_vocab
        }

        info.datasets = datasets

        if src_embed_opt is not None:
            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
            info.embeddings['words'] = embed

        return info


if __name__ == "__main__":
    datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
                "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
    datainfo = sst2Loader().process(datapath, char_level_op=True)
    # print(datainfo.datasets["train"])
    len_count = 0
    for instance in datainfo.datasets["train"]:
        len_count += len(instance["chars"])

    ave_len = len_count / len(datainfo.datasets["train"])
    print(ave_len)
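    # The vocabularies built by process() can be inspected in the same way; the
    # exact sizes depend on the local SST-2 copy, so none are assumed here:
    #   print(len(datainfo.vocabs["words"]), len(datainfo.vocabs["target"]))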