Major update: 1. update requirements; 2. move the contents of modules.aggregator into modules.encoder; 3. rename SQuADMetric to ExtractiveQAMetric; 4. update the reproduction README; 5. move the dataloaders in reproduction/text_classification to fastNLP.io.data_loader and adapt them accordingly

xuyige 2019-07-09 01:01:47 +08:00
parent f33008a967
commit d70aa96e4c
21 changed files with 632 additions and 249 deletions

View File

@ -37,7 +37,7 @@ __all__ = [
"AccuracyMetric",
"SpanFPreRecMetric",
"SQuADMetric",
"ExtractiveQAMetric",
"Optimizer",
"SGD",
@ -61,3 +61,4 @@ __version__ = '0.4.0'
from .core import *
from . import models
from . import modules
from .io import data_loader
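
The two user-visible effects of this file's change are the renamed metric export and the eagerly imported `data_loader` package. A minimal sketch of the new top-level imports, assuming fastNLP 0.4.x with this commit applied:

```python
# Sketch only: assumes fastNLP 0.4.x with this commit applied.
from fastNLP import ExtractiveQAMetric          # formerly exported as SQuADMetric
from fastNLP.io import data_loader              # now imported by fastNLP/__init__.py

print(ExtractiveQAMetric.__bases__)             # subclass of MetricBase
print([name for name in dir(data_loader) if name.endswith('Loader')])
```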

View File

@ -21,7 +21,7 @@ from .dataset import DataSet
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
from .metrics import AccuracyMetric, SpanFPreRecMetric, SQuADMetric
from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric
from .optimizer import Optimizer, SGD, Adam
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
from .tester import Tester

View File

@ -6,7 +6,7 @@ __all__ = [
"MetricBase",
"AccuracyMetric",
"SpanFPreRecMetric",
"SQuADMetric"
"ExtractiveQAMetric"
]
import inspect
@ -24,6 +24,7 @@ from .utils import seq_len_to_mask
from .vocabulary import Vocabulary
from abc import abstractmethod
class MetricBase(object):
"""
Base class for all metrics. Any Metric passed to Trainer or Tester must inherit from this class and override the evaluate() and get_metric() methods.
@ -735,11 +736,11 @@ def _pred_topk(y_prob, k=1):
return y_pred_topk, y_prob_topk
class SQuADMetric(MetricBase):
r"""
Alias: :class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric`
class ExtractiveQAMetric(MetricBase):
"""
Alias: :class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric`
Metric for the SQuAD dataset
Metric for extractive QA tasks such as SQuAD.
:param pred1: the mapping for `pred1` in the parameter map; None means the mapping is `pred1` -> `pred1`
:param pred2: the mapping for `pred2` in the parameter map; None means the mapping is `pred2` -> `pred2`
@ -755,7 +756,7 @@ class SQuADMetric(MetricBase):
def __init__(self, pred1=None, pred2=None, target1=None, target2=None,
beta=1, right_open=True, print_predict_stat=False):
super(SQuADMetric, self).__init__()
super(ExtractiveQAMetric, self).__init__()
self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2)
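
For reference, a hedged usage sketch of the renamed metric. The constructor arguments follow the signature in the hunk above; the concrete field names ('start_logits', 'answer_start', etc.) are placeholders, not names mandated by fastNLP:

```python
from fastNLP import ExtractiveQAMetric

# pred1/pred2 map the model's two span outputs, target1/target2 the gold span fields;
# the concrete names below are illustrative only.
metric = ExtractiveQAMetric(pred1='start_logits', pred2='end_logits',
                            target1='answer_start', target2='answer_end',
                            beta=1, right_open=True, print_predict_stat=False)
# The instance is then passed to Trainer/Tester like any other MetricBase subclass,
# e.g. Tester(data=dev_data, model=model, metrics=metric).
```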

View File

@ -4,16 +4,26 @@
The usage of these modules is as follows:
"""
__all__ = [
'SSTLoader',
'IMDBLoader',
'MatchingLoader',
'SNLILoader',
'MNLILoader',
'MTL16Loader',
'QNLILoader',
'QuoraLoader',
'RTELoader',
'SSTLoader',
'SNLILoader',
'YelpLoader',
]
from .imdb import IMDBLoader
from .matching import MatchingLoader
from .mnli import MNLILoader
from .mtl import MTL16Loader
from .qnli import QNLILoader
from .quora import QuoraLoader
from .rte import RTELoader
from .snli import SNLILoader
from .sst import SSTLoader
from .matching import MatchingLoader, SNLILoader, \
MNLILoader, QNLILoader, QuoraLoader, RTELoader
from .yelp import YelpLoader
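
After this change the classification and matching loaders share one import path; a sketch of importing everything listed in `__all__` above:

```python
from fastNLP.io.data_loader import (
    IMDBLoader, MatchingLoader, MNLILoader, MTL16Loader, QNLILoader,
    QuoraLoader, RTELoader, SNLILoader, SSTLoader, YelpLoader,
)
```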

View File

@ -0,0 +1,96 @@
from typing import Union, Dict
from ..embed_loader import EmbeddingOption, EmbedLoader
from ..base_loader import DataSetLoader, DataInfo
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.const import Const
from ..utils import get_tokenizer
class IMDBLoader(DataSetLoader):
"""
Reads the IMDB dataset. The loaded DataSet contains the following fields:
words: list(str), the text to be classified
target: str, the label of the text
"""
def __init__(self):
super(IMDBLoader, self).__init__()
self.tokenizer = get_tokenizer()
def _load(self, path):
dataset = DataSet()
with open(path, 'r', encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
parts = line.split('\t')
target = parts[0]
words = self.tokenizer(parts[1].lower())
dataset.append(Instance(words=words, target=target))
if len(dataset) == 0:
raise RuntimeError(f"{path} has no valid data.")
return dataset
def process(self,
paths: Union[str, Dict[str, str]],
src_vocab_opt: VocabularyOption = None,
tgt_vocab_opt: VocabularyOption = None,
char_level_op=False):
datasets = {}
info = DataInfo()
for name, path in paths.items():
dataset = self.load(path)
datasets[name] = dataset
def wordtochar(words):
chars = []
for word in words:
word = word.lower()
for char in word:
chars.append(char)
chars.append('')
chars.pop()
return chars
if char_level_op:
for dataset in datasets.values():
dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')
datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
src_vocab.from_dataset(datasets['train'], field_name='words')
src_vocab.index_dataset(*datasets.values(), field_name='words')
tgt_vocab = Vocabulary(unknown=None, padding=None) \
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
tgt_vocab.from_dataset(datasets['train'], field_name='target')
tgt_vocab.index_dataset(*datasets.values(), field_name='target')
info.vocabs = {
Const.INPUT: src_vocab,
Const.TARGET: tgt_vocab
}
info.datasets = datasets
for name, dataset in info.datasets.items():
dataset.set_input(Const.INPUT)
dataset.set_target(Const.TARGET)
return info
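
A hedged usage sketch of the new loader; the file locations are placeholders and assume IMDB has already been converted to one tab-separated `label<TAB>text` line per review, as `_load()` above expects:

```python
from fastNLP.io.data_loader import IMDBLoader

loader = IMDBLoader()
data_info = loader.process(
    paths={'train': 'imdb/train.tsv', 'test': 'imdb/test.tsv'},  # placeholder paths
    char_level_op=False,
)
# process() carves a dev set out of 'train' and builds word/target vocabularies.
print(data_info.datasets.keys())
print(data_info.vocabs.keys())
```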

View File

@ -5,14 +5,13 @@ from typing import Union, Dict
from ...core.const import Const
from ...core.vocabulary import Vocabulary
from ..base_loader import DataInfo, DataSetLoader
from ..dataset_loader import JsonLoader, CSVLoader
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from ...modules.encoder._bert import BertTokenizer
class MatchingLoader(DataSetLoader):
"""
Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.data_loader.MatchingLoader`
Loads datasets for Matching tasks
@ -227,204 +226,3 @@ class MatchingLoader(DataSetLoader):
data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])
return data_info
class SNLILoader(MatchingLoader, JsonLoader):
"""
Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
Reads the SNLI dataset. The loaded DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
"""
def __init__(self, paths: dict=None):
fields = {
'sentence1_binary_parse': Const.INPUTS(0),
'sentence2_binary_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
}
paths = paths if paths is not None else {
'train': 'snli_1.0_train.jsonl',
'dev': 'snli_1.0_dev.jsonl',
'test': 'snli_1.0_test.jsonl'}
MatchingLoader.__init__(self, paths=paths)
JsonLoader.__init__(self, fields=fields)
def _load(self, path):
ds = JsonLoader._load(self, path)
parentheses_table = str.maketrans({'(': None, ')': None})
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
new_field_name=Const.INPUTS(0))
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
new_field_name=Const.INPUTS(1))
ds.drop(lambda x: x[Const.TARGET] == '-')
return ds
class RTELoader(MatchingLoader, CSVLoader):
"""
Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`
Reads the RTE dataset. The loaded DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source:
"""
def __init__(self, paths: dict=None):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev': 'dev.tsv',
'test': 'test.tsv'  # the test set has no labels
}
MatchingLoader.__init__(self, paths=paths)
self.fields = {
'sentence1': Const.INPUTS(0),
'sentence2': Const.INPUTS(1),
'label': Const.TARGET,
}
CSVLoader.__init__(self, sep='\t')
def _load(self, path):
ds = CSVLoader._load(self, path)
for k, v in self.fields.items():
if v in ds.get_field_names():
ds.rename_field(k, v)
for fields in ds.get_all_fields():
if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
return ds
class QNLILoader(MatchingLoader, CSVLoader):
"""
Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`
Reads the QNLI dataset. The loaded DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source:
"""
def __init__(self, paths: dict=None):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev': 'dev.tsv',
'test': 'test.tsv'  # the test set has no labels
}
MatchingLoader.__init__(self, paths=paths)
self.fields = {
'question': Const.INPUTS(0),
'sentence': Const.INPUTS(1),
'label': Const.TARGET,
}
CSVLoader.__init__(self, sep='\t')
def _load(self, path):
ds = CSVLoader._load(self, path)
for k, v in self.fields.items():
if v in ds.get_field_names():
ds.rename_field(k, v)
for fields in ds.get_all_fields():
if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
return ds
class MNLILoader(MatchingLoader, CSVLoader):
"""
Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`
Reads the MNLI dataset. The loaded DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source:
"""
def __init__(self, paths: dict=None):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev_matched': 'dev_matched.tsv',
'dev_mismatched': 'dev_mismatched.tsv',
'test_matched': 'test_matched.tsv',
'test_mismatched': 'test_mismatched.tsv',
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
# test_0.9_matched and test_0.9_mismatched belong to the MNLI 0.9 release; data source: Kaggle
}
MatchingLoader.__init__(self, paths=paths)
CSVLoader.__init__(self, sep='\t')
self.fields = {
'sentence1_binary_parse': Const.INPUTS(0),
'sentence2_binary_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
}
def _load(self, path):
ds = CSVLoader._load(self, path)
for k, v in self.fields.items():
if k in ds.get_field_names():
ds.rename_field(k, v)
if Const.TARGET in ds.get_field_names():
if ds[0][Const.TARGET] == 'hidden':
ds.delete_field(Const.TARGET)
parentheses_table = str.maketrans({'(': None, ')': None})
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
new_field_name=Const.INPUTS(0))
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
new_field_name=Const.INPUTS(1))
if Const.TARGET in ds.get_field_names():
ds.drop(lambda x: x[Const.TARGET] == '-')
return ds
class QuoraLoader(MatchingLoader, CSVLoader):
"""
Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`
Reads the Quora dataset. The loaded DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source:
"""
def __init__(self, paths: dict=None):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev': 'dev.tsv',
'test': 'test.tsv',
}
MatchingLoader.__init__(self, paths=paths)
CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
def _load(self, path):
ds = CSVLoader._load(self, path)
return ds

View File

@ -0,0 +1,60 @@
from ...core import Const
from .matching import MatchingLoader
from ..dataset_loader import CSVLoader
class MNLILoader(MatchingLoader, CSVLoader):
"""
Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.data_loader.MNLILoader`
Reads the MNLI dataset. The loaded DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source:
"""
def __init__(self, paths: dict=None):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev_matched': 'dev_matched.tsv',
'dev_mismatched': 'dev_mismatched.tsv',
'test_matched': 'test_matched.tsv',
'test_mismatched': 'test_mismatched.tsv',
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
# test_0.9_matched and test_0.9_mismatched belong to the MNLI 0.9 release; data source: Kaggle
}
MatchingLoader.__init__(self, paths=paths)
CSVLoader.__init__(self, sep='\t')
self.fields = {
'sentence1_binary_parse': Const.INPUTS(0),
'sentence2_binary_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
}
def _load(self, path):
ds = CSVLoader._load(self, path)
for k, v in self.fields.items():
if k in ds.get_field_names():
ds.rename_field(k, v)
if Const.TARGET in ds.get_field_names():
if ds[0][Const.TARGET] == 'hidden':
ds.delete_field(Const.TARGET)
parentheses_table = str.maketrans({'(': None, ')': None})
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
new_field_name=Const.INPUTS(0))
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
new_field_name=Const.INPUTS(1))
if Const.TARGET in ds.get_field_names():
ds.drop(lambda x: x[Const.TARGET] == '-')
return ds
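
The `_load()` method strips the brackets from the binary-parse columns and whitespace-tokenizes what is left; a small standalone illustration of that transformation:

```python
# Standalone illustration of the binary-parse preprocessing in _load() above.
parse = "( ( The dog ) ( ( is ( very cute ) ) . ) )"
parentheses_table = str.maketrans({'(': None, ')': None})
tokens = parse.translate(parentheses_table).strip().split()
print(tokens)   # ['The', 'dog', 'is', 'very', 'cute', '.']
```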

View File

@ -0,0 +1,65 @@
from typing import Union, Dict
from ..base_loader import DataInfo
from ..dataset_loader import CSVLoader
from ...core.vocabulary import Vocabulary, VocabularyOption
from ...core.const import Const
from ..utils import check_dataloader_paths
class MTL16Loader(CSVLoader):
"""
Reads the MTL16 dataset. The loaded DataSet contains the following fields:
words: list(str), the text to be classified
target: str, the label of the text
Data source: https://pan.baidu.com/s/1c2L6vdA
"""
def __init__(self):
super(MTL16Loader, self).__init__(headers=(Const.TARGET, Const.INPUT), sep='\t')
def _load(self, path):
dataset = super(MTL16Loader, self)._load(path)
dataset.apply(lambda x: x[Const.INPUT].lower().split(), new_field_name=Const.INPUT)
if len(dataset) == 0:
raise RuntimeError(f"{path} has no valid data.")
return dataset
def process(self,
paths: Union[str, Dict[str, str]],
src_vocab_opt: VocabularyOption = None,
tgt_vocab_opt: VocabularyOption = None,):
paths = check_dataloader_paths(paths)
datasets = {}
info = DataInfo()
for name, path in paths.items():
dataset = self.load(path)
datasets[name] = dataset
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)
tgt_vocab = Vocabulary(unknown=None, padding=None) \
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET)
tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET)
info.vocabs = {
Const.INPUT: src_vocab,
Const.TARGET: tgt_vocab
}
info.datasets = datasets
for name, dataset in info.datasets.items():
dataset.set_input(Const.INPUT)
dataset.set_target(Const.TARGET)
return info
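
A hedged usage sketch of MTL16Loader; the paths point to one of the sixteen task folders and are placeholders:

```python
from fastNLP.io.data_loader import MTL16Loader

loader = MTL16Loader()   # expects tab-separated "label<TAB>text" lines
data_info = loader.process(paths={'train': 'mtl16/apparel/train.tsv',   # placeholder paths
                                  'dev': 'mtl16/apparel/dev.tsv',
                                  'test': 'mtl16/apparel/test.tsv'})
print(data_info.vocabs.keys())   # vocabularies keyed by Const.INPUT and Const.TARGET
```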

View File

@ -0,0 +1,45 @@
from ...core import Const
from .matching import MatchingLoader
from ..dataset_loader import CSVLoader
class QNLILoader(MatchingLoader, CSVLoader):
"""
Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.data_loader.QNLILoader`
Reads the QNLI dataset. The loaded DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source:
"""
def __init__(self, paths: dict=None):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev': 'dev.tsv',
'test': 'test.tsv'  # the test set has no labels
}
MatchingLoader.__init__(self, paths=paths)
self.fields = {
'question': Const.INPUTS(0),
'sentence': Const.INPUTS(1),
'label': Const.TARGET,
}
CSVLoader.__init__(self, sep='\t')
def _load(self, path):
ds = CSVLoader._load(self, path)
for k, v in self.fields.items():
if k in ds.get_field_names():
ds.rename_field(k, v)
for fields in ds.get_all_fields():
if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
return ds

View File

@ -0,0 +1,32 @@
from ...core import Const
from .matching import MatchingLoader
from ..dataset_loader import CSVLoader
class QuoraLoader(MatchingLoader, CSVLoader):
"""
Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.data_loader.QuoraLoader`
Reads the Quora dataset. The loaded DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source:
"""
def __init__(self, paths: dict=None):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev': 'dev.tsv',
'test': 'test.tsv',
}
MatchingLoader.__init__(self, paths=paths)
CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
def _load(self, path):
ds = CSVLoader._load(self, path)
return ds

View File

@ -0,0 +1,45 @@
from ...core import Const
from .matching import MatchingLoader
from ..dataset_loader import CSVLoader
class RTELoader(MatchingLoader, CSVLoader):
"""
Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.data_loader.RTELoader`
Reads the RTE dataset. The loaded DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source:
"""
def __init__(self, paths: dict=None):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev': 'dev.tsv',
'test': 'test.tsv'  # the test set has no labels
}
MatchingLoader.__init__(self, paths=paths)
self.fields = {
'sentence1': Const.INPUTS(0),
'sentence2': Const.INPUTS(1),
'label': Const.TARGET,
}
CSVLoader.__init__(self, sep='\t')
def _load(self, path):
ds = CSVLoader._load(self, path)
for k, v in self.fields.items():
if k in ds.get_field_names():
ds.rename_field(k, v)
for fields in ds.get_all_fields():
if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
return ds

View File

@ -0,0 +1,44 @@
from ...core import Const
from .matching import MatchingLoader
from ..dataset_loader import JsonLoader
class SNLILoader(MatchingLoader, JsonLoader):
"""
Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.data_loader.SNLILoader`
Reads the SNLI dataset. The loaded DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
"""
def __init__(self, paths: dict=None):
fields = {
'sentence1_binary_parse': Const.INPUTS(0),
'sentence2_binary_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
}
paths = paths if paths is not None else {
'train': 'snli_1.0_train.jsonl',
'dev': 'snli_1.0_dev.jsonl',
'test': 'snli_1.0_test.jsonl'}
MatchingLoader.__init__(self, paths=paths)
JsonLoader.__init__(self, fields=fields)
def _load(self, path):
ds = JsonLoader._load(self, path)
parentheses_table = str.maketrans({'(': None, ')': None})
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
new_field_name=Const.INPUTS(0))
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
new_field_name=Const.INPUTS(1))
ds.drop(lambda x: x[Const.TARGET] == '-')
return ds
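
A usage sketch under the assumption that `MatchingLoader.process()` (not shown in this diff) accepts a directory containing the default snli_1.0_*.jsonl files; the data path is a placeholder:

```python
from fastNLP.io.data_loader import SNLILoader

loader = SNLILoader()    # defaults to snli_1.0_{train,dev,test}.jsonl
data_info = loader.process(paths='/path/to/snli_1.0')   # placeholder directory
for name, ds in data_info.datasets.items():
    print(name, len(ds))   # instances with '-' gold labels have been dropped
```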

View File

@ -1,19 +1,19 @@
from typing import Iterable
from typing import Union, Dict
from nltk import Tree
import spacy
from ..base_loader import DataInfo, DataSetLoader
from ..dataset_loader import CSVLoader
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.const import Const
from ...core.instance import Instance
from ..utils import check_dataloader_paths, get_tokenizer
class SSTLoader(DataSetLoader):
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
DATA_DIR = 'sst/'
"""
Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.data_loader.SSTLoader`
Reads the SST dataset. The loaded DataSet contains the fields::
@ -26,6 +26,9 @@ class SSTLoader(DataSetLoader):
:param fine_grained: whether to use the 5-class SST-5 label scheme; if ``False``, use SST-2. Default: ``False``
"""
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
DATA_DIR = 'sst/'
def __init__(self, subtree=False, fine_grained=False):
self.subtree = subtree
@ -98,3 +101,72 @@ class SSTLoader(DataSetLoader):
return info
class SST2Loader(CSVLoader):
"""
Data source "SST": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
"""
def __init__(self):
super(SST2Loader, self).__init__(sep='\t')
self.tokenizer = get_tokenizer()
self.field = {'sentence': Const.INPUT, 'label': Const.TARGET}
def _load(self, path: str) -> DataSet:
ds = super(SST2Loader, self)._load(path)
ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT)
print("all count:", len(ds))
return ds
def process(self,
paths: Union[str, Dict[str, str]],
src_vocab_opt: VocabularyOption = None,
tgt_vocab_opt: VocabularyOption = None,
char_level_op=False):
paths = check_dataloader_paths(paths)
datasets = {}
info = DataInfo()
for name, path in paths.items():
dataset = self.load(path)
datasets[name] = dataset
def wordtochar(words):
chars = []
for word in words:
word = word.lower()
for char in word:
chars.append(char)
chars.append('')
chars.pop()
return chars
input_name, target_name = Const.INPUT, Const.TARGET
info.vocabs={}
# split the words into characters
if char_level_op:
for dataset in datasets.values():
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)
tgt_vocab = Vocabulary(unknown=None, padding=None) \
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET)
tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET)
info.vocabs = {
Const.INPUT: src_vocab,
Const.TARGET: tgt_vocab
}
info.datasets = datasets
for name, dataset in info.datasets.items():
dataset.set_input(Const.INPUT)
dataset.set_target(Const.TARGET)
return info
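
A hedged usage sketch of the new SST2Loader (note it is not re-exported from `fastNLP.io.data_loader.__init__`, so it is imported from the `sst` module directly); the paths are placeholders for the GLUE SST-2 tsv files:

```python
from fastNLP.io.data_loader.sst import SST2Loader

loader = SST2Loader()
data_info = loader.process(
    paths={'train': 'SST-2/train.tsv', 'dev': 'SST-2/dev.tsv'},  # placeholder paths
    char_level_op=False,
)
print(data_info.vocabs.keys())
```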

View File

@ -0,0 +1,126 @@
import csv
from typing import Iterable
from ...core.const import Const
from ...core import DataSet, Instance, Vocabulary
from ...core.vocabulary import VocabularyOption
from ..base_loader import DataInfo,DataSetLoader
from typing import Union, Dict
from ..utils import check_dataloader_paths, get_tokenizer
class YelpLoader(DataSetLoader):
"""
Reads the yelp_full / yelp_polarity dataset. The loaded DataSet contains the fields:
words: list(str), the text to be classified
target: str, the label of the text
chars: list(str), the un-indexed list of characters
Datasets: yelp_full / yelp_polarity
:param fine_grained: whether to keep the 5-class labels; if ``False``, collapse 'very negative'/'very positive' into 'negative'/'positive'. Default: ``False``
:param lower: whether to lowercase the text automatically. Default: ``False``
"""
def __init__(self, fine_grained=False, lower=False):
super(YelpLoader, self).__init__()
tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
'4.0': 'positive', '5.0': 'very positive'}
if not fine_grained:
tag_v['1.0'] = tag_v['2.0']
tag_v['5.0'] = tag_v['4.0']
self.fine_grained = fine_grained
self.tag_v = tag_v
self.lower = lower
self.tokenizer = get_tokenizer()
def _load(self, path):
ds = DataSet()
csv_reader = csv.reader(open(path, encoding='utf-8'))
all_count = 0
real_count = 0
for row in csv_reader:
all_count += 1
if len(row) == 2:
target = self.tag_v[row[0] + ".0"]
words = clean_str(row[1], self.tokenizer, self.lower)
if len(words) != 0:
ds.append(Instance(words=words, target=target))
real_count += 1
print("all count:", all_count)
print("real count:", real_count)
return ds
def process(self, paths: Union[str, Dict[str, str]],
train_ds: Iterable[str] = None,
src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None,
char_level_op=False):
paths = check_dataloader_paths(paths)
info = DataInfo(datasets=self.load(paths))
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary(unknown=None, padding=None) \
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
_train_ds = [info.datasets[name]
for name in train_ds] if train_ds else info.datasets.values()
def wordtochar(words):
chars = []
for word in words:
word = word.lower()
for char in word:
chars.append(char)
chars.append('')
chars.pop()
return chars
input_name, target_name = Const.INPUT, Const.TARGET
info.vocabs = {}
# split the words into characters
if char_level_op:
for dataset in info.datasets.values():
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
else:
src_vocab.from_dataset(*_train_ds, field_name=input_name)
src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name)
info.vocabs[input_name] = src_vocab
tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
tgt_vocab.index_dataset(
*info.datasets.values(),
field_name=target_name, new_field_name=target_name)
info.vocabs[target_name] = tgt_vocab
info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False)
for name, dataset in info.datasets.items():
dataset.set_input(Const.INPUT)
dataset.set_target(Const.TARGET)
return info
def clean_str(sentence, tokenizer, char_lower=False):
"""
heavily borrowed from github
https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
:param sentence: a str
:return: a list of cleaned word strings
"""
if char_lower:
sentence = sentence.lower()
import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
words = tokenizer(sentence)
words_collection = []
for word in words:
if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']:
continue
tt = nonalpnum.split(word)
t = ''.join(tt)
if t != '':
words_collection.append(t)
return words_collection
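
A worked example of `clean_str()` above; `str.split` stands in for the spaCy tokenizer returned by `get_tokenizer()`, so real tokenization may differ slightly:

```python
from fastNLP.io.data_loader.yelp import clean_str

tokens = clean_str("The fries were great -- 10/10, would eat again!", str.split, char_lower=True)
print(tokens)   # ['the', 'fries', 'were', 'great', '1010', 'would', 'eat', 'again!']
```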

View File

@ -1,14 +0,0 @@
__all__ = [
"MaxPool",
"MaxPoolWithMask",
"AvgPool",
"MultiHeadAttention",
]
from .pooling import MaxPool
from .pooling import MaxPoolWithMask
from .pooling import AvgPool
from .pooling import AvgPoolWithMask
from .attention import MultiHeadAttention

View File

@ -8,9 +8,9 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..dropout import TimestepDropout
from fastNLP.modules.dropout import TimestepDropout
from ..utils import initial_parameter
from fastNLP.modules.utils import initial_parameter
class DotAttention(nn.Module):

View File

@ -3,7 +3,7 @@ __all__ = [
]
from torch import nn
from ..aggregator.attention import MultiHeadAttention
from fastNLP.modules.encoder.attention import MultiHeadAttention
from ..dropout import TimestepDropout
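
With `modules.aggregator` folded into `modules.encoder`, user code imports attention (and, presumably, the pooling classes) from the encoder package. A sketch; the constructor arguments are assumed from the fastNLP 0.4.x signature rather than shown in this diff:

```python
from fastNLP.modules.encoder.attention import MultiHeadAttention

# Constructor arguments are an assumption based on fastNLP 0.4.x, not this diff.
attn = MultiHeadAttention(input_size=256, key_size=64, value_size=64, num_head=4)
```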

View File

@ -11,7 +11,7 @@ Coreference resolution is the task of finding all expressions in a text that refer to the same
Due to copyright restrictions, the dataset cannot be distributed here; please download it yourself.
The original dataset is in the CoNLL format; see the dataset's official introduction page for details.
The implementation follows the preprocessing of the paper's author Lee; see the [link](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh).
The implementation follows the preprocessing of the paper's author Lee; for details, see the [link](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh).
The processed dataset is in JSON format, for example:
```
{
@ -25,12 +25,12 @@ Coreference resolution is the task of finding all expressions in a text that refer to the same
### Embedding downloads
[turian emdedding](https://lil.cs.washington.edu/coref/turian.50d.txt)
[glove embedding]( https://nlp.stanford.edu/data/glove.840B.300d.zip)
[glove embedding](https://nlp.stanford.edu/data/glove.840B.300d.zip)
## Running
```python
```shell
# training
CUDA_VISIBLE_DEVICES=0 python train.py
# evaluation
@ -39,9 +39,9 @@ CUDA_VISIBLE_DEVICES=0 python valid.py
## Results
The original authors report 67.2% on the test set; the AllenNLP reproduction reports [63.0%](https://allennlp.org/models).
Note that allenNLP trained without speaker information, without variational dropout, and used only 100 antecedents instead of 250.
Note that AllenNLP trained without speaker information, without variational dropout, and used only 100 antecedents instead of 250.
With the same hyperparameters and configuration as allenNLP, this reproduction achieves an F1 of 63.6%.
With the same hyperparameters and configuration as AllenNLP, this reproduction achieves an F1 of 63.6%.
## Issues

View File

@ -2,7 +2,7 @@
Here we use fastNLP to reproduce several well-known models for Matching tasks, aiming to match the performance reported in their papers. The evaluation metric for all of these tasks is accuracy (%).
The reproduced models are (ordered by publication date):
- CNTN: model code (still in progress) [](); training code (still in progress) []().
- CNTN: [model code](model/cntn.py); [training code](matching_cntn.py).
Paper: [Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844).
- ESIM: [model code](model/esim.py); [training code](matching_esim.py).
Paper: [Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf).
@ -21,7 +21,7 @@
model name | SNLI | MNLI | RTE | QNLI | Quora
:---: | :---: | :---: | :---: | :---: | :---:
CNTN [](); [paper](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - |
CNTN [code](model/cntn.py); [paper](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 77.79 vs - | 63.29/63.16(dev) vs - | 57.04(dev) vs - | 62.38(dev) vs - | - |
ESIM [code](model/esim.py); [paper](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - |
DIIN [](); [paper](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 |
MwAN [](); [paper](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 |
@ -44,7 +44,7 @@ Performance on Test set:
model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large
:---: | :---: | :---: | :---: | :---: | :---: | :---:
__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16
__performance__ | 77.79 | 88.13 | - | 87.9 | 90.6 | 91.16
## MNLI
[Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/)
@ -60,7 +60,7 @@ Performance on Test set(matched/mismatched):
model name | CNTN | ESIM | DIIN | MwAN | BERT-Base
:---: | :---: | :---: | :---: | :---: | :---: |
__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - |
__performance__ | 63.29/63.16(dev) | 77.78/76.49 | - | 77.3/76.7(dev) | - |
## RTE
@ -92,7 +92,7 @@ Performance on __Dev__ set:
model name | CNTN | ESIM | DIIN | MwAN | BERT
:---: | :---: | :---: | :---: | :---: | :---:
__performance__ | - | 76.97 | - | 74.6 | -
__performance__ | 62.38 | 76.97 | - | 74.6 | -
## Quora

View File

@ -3,3 +3,5 @@ torch>=1.0.0
tqdm>=4.28.1
nltk>=3.4.1
requests
spacy
h5py