From ba47fb851a588c211e1c68d5b8f8ca878029ec74 Mon Sep 17 00:00:00 2001
From: wyg <1505116161@qq.com>
Date: Mon, 8 Jul 2019 19:45:39 +0800
Subject: [PATCH] [verify] sst2loader

---
 .../text_classification/data/sstLoader.py    | 191 ------------------
 1 file changed, 191 deletions(-)
 delete mode 100644 reproduction/text_classification/data/sstLoader.py

diff --git a/reproduction/text_classification/data/sstLoader.py b/reproduction/text_classification/data/sstLoader.py
deleted file mode 100644
index 14524ea5..00000000
--- a/reproduction/text_classification/data/sstLoader.py
+++ /dev/null
@@ -1,191 +0,0 @@
-from typing import Iterable
-from nltk import Tree
-from fastNLP.io.base_loader import DataInfo, DataSetLoader
-from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
-from fastNLP import DataSet
-from fastNLP import Instance
-from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
-import csv
-from typing import Union, Dict
-from reproduction.utils import check_dataloader_paths, get_tokenizer
-
-class SSTLoader(DataSetLoader):
-    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
-    DATA_DIR = 'sst/'
-
-    """
-    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
-
-    Reads the SST dataset. The resulting DataSet contains the fields::
-
-        words: list(str), the text to be classified
-        target: str, the label of the text
-
-    Data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
-
-    :param subtree: whether to expand the data into subtrees to enlarge the dataset. Default: ``False``
-    :param fine_grained: whether to use the SST-5 label set; if ``False``, SST-2 is used. Default: ``False``
-    """
-
-    def __init__(self, subtree=False, fine_grained=False):
-        self.subtree = subtree
-
-        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
-                 '3': 'positive', '4': 'very positive'}
-        if not fine_grained:
-            tag_v['0'] = tag_v['1']
-            tag_v['4'] = tag_v['3']
-        self.tag_v = tag_v
-
-    def _load(self, path):
-        """
-
-        :param str path: path to the data file
-        :return: a :class:`~fastNLP.DataSet` object
-        """
-        datalist = []
-        with open(path, 'r', encoding='utf-8') as f:
-            datas = []
-            for l in f:
-                datas.extend([(s, self.tag_v[t])
-                              for s, t in self._get_one(l, self.subtree)])
-        ds = DataSet()
-        for words, tag in datas:
-            ds.append(Instance(words=words, target=tag))
-        return ds
-
-    @staticmethod
-    def _get_one(data, subtree):
-        tree = Tree.fromstring(data)
-        if subtree:
-            return [(t.leaves(), t.label()) for t in tree.subtrees()]
-        return [(tree.leaves(), tree.label())]
-
-    def process(self,
-                paths,
-                train_ds: Iterable[str] = None,
-                src_vocab_op: VocabularyOption = None,
-                tgt_vocab_op: VocabularyOption = None,
-                src_embed_op: EmbeddingOption = None):
-        input_name, target_name = 'words', 'target'
-        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
-        tgt_vocab = Vocabulary(unknown=None, padding=None) \
-            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
-
-        info = DataInfo(datasets=self.load(paths))
-        _train_ds = [info.datasets[name]
-                     for name in train_ds] if train_ds else info.datasets.values()
-        src_vocab.from_dataset(*_train_ds, field_name=input_name)
-        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
-        src_vocab.index_dataset(
-            *info.datasets.values(),
-            field_name=input_name, new_field_name=input_name)
-        tgt_vocab.index_dataset(
-            *info.datasets.values(),
-            field_name=target_name, new_field_name=target_name)
-        info.vocabs = {
-            input_name: src_vocab,
-            target_name: tgt_vocab
-        }
-
-        if src_embed_op is not None:
-            src_embed_op.vocab = src_vocab
-            init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
-            info.embeddings[input_name] = init_emb
-
-        for name, dataset in info.datasets.items():
-            dataset.set_input(input_name)
-            dataset.set_target(target_name)
-
-        return info
-
-class sst2Loader(DataSetLoader):
-    '''
-    Data source "SST": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8'
-    '''
-    def __init__(self):
-        super(sst2Loader, self).__init__()
-        self.tokenizer = get_tokenizer()
-
-    def _load(self, path: str) -> DataSet:
-        ds = DataSet()
-        all_count = 0
-        csv_reader = csv.reader(open(path, encoding='utf-8'), delimiter='\t')
-        skip_row = 0
-        for idx, row in enumerate(csv_reader):
-            if idx <= skip_row:
-                continue
-            target = row[1]
-            words = self.tokenizer(row[0])
-            ds.append(Instance(words=words, target=target))
-            all_count += 1
-        print("all count:", all_count)
-        return ds
-
-    def process(self,
-                paths: Union[str, Dict[str, str]],
-                src_vocab_opt: VocabularyOption = None,
-                tgt_vocab_opt: VocabularyOption = None,
-                src_embed_opt: EmbeddingOption = None,
-                char_level_op=False):
-
-        paths = check_dataloader_paths(paths)
-        datasets = {}
-        info = DataInfo()
-        for name, path in paths.items():
-            dataset = self.load(path)
-            datasets[name] = dataset
-
-        def wordtochar(words):
-            chars = []
-            for word in words:
-                word = word.lower()
-                for char in word:
-                    chars.append(char)
-                chars.append('')
-            chars.pop()
-            return chars
-
-        input_name, target_name = 'words', 'target'
-        info.vocabs = {}
-
-        # optionally split the words into character form
-        if char_level_op:
-            for dataset in datasets.values():
-                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')
-
-        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
-        src_vocab.from_dataset(datasets['train'], field_name='words')
-        src_vocab.index_dataset(*datasets.values(), field_name='words')
-
-        tgt_vocab = Vocabulary(unknown=None, padding=None) \
-            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
-        tgt_vocab.from_dataset(datasets['train'], field_name='target')
-        tgt_vocab.index_dataset(*datasets.values(), field_name='target')
-
-        info.vocabs = {
-            "words": src_vocab,
-            "target": tgt_vocab
-        }
-
-        info.datasets = datasets
-
-        if src_embed_opt is not None:
-            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
-            info.embeddings['words'] = embed
-
-        return info
-
-if __name__ == "__main__":
-    datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
-                "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
-    datainfo = sst2Loader().process(datapath, char_level_op=True)
-    # print(datainfo.datasets["train"])
-    len_count = 0
-    for instance in datainfo.datasets["train"]:
-        len_count += len(instance["chars"])
-
-    ave_len = len_count / len(datainfo.datasets["train"])
-    print(ave_len)
\ No newline at end of file