mirror of https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 03:37:55 +08:00

Major update: 1. update requirements; 2. move the contents of modules.aggregator into modules.encoder; 3. rename SQuADMetric to ExtractiveQAMetric; 4. update the reproduction README; 5. move the dataloaders of reproduction/text_classification into fastNLP.io.data_loader and adapt them accordingly

This commit is contained in:
parent f33008a967
commit d70aa96e4c
fastNLP/__init__.py

@@ -37,7 +37,7 @@ __all__ = [
    "AccuracyMetric",
    "SpanFPreRecMetric",
-    "SQuADMetric",
+    "ExtractiveQAMetric",

    "Optimizer",
    "SGD",
@@ -61,3 +61,4 @@ __version__ = '0.4.0'
from .core import *
from . import models
from . import modules
+from .io import data_loader
fastNLP/core/__init__.py

@@ -21,7 +21,7 @@ from .dataset import DataSet
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
-from .metrics import AccuracyMetric, SpanFPreRecMetric, SQuADMetric
+from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric
from .optimizer import Optimizer, SGD, Adam
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
from .tester import Tester
fastNLP/core/metrics.py

@@ -6,7 +6,7 @@ __all__ = [
    "MetricBase",
    "AccuracyMetric",
    "SpanFPreRecMetric",
-    "SQuADMetric"
+    "ExtractiveQAMetric"
]

import inspect
@@ -24,6 +24,7 @@ from .utils import seq_len_to_mask
from .vocabulary import Vocabulary
+from abc import abstractmethod


class MetricBase(object):
    """
    Base class for all metrics. Any Metric passed to Trainer or Tester must inherit from this class and override the evaluate() and get_metric() methods.
@@ -735,11 +736,11 @@ def _pred_topk(y_prob, k=1):
    return y_pred_topk, y_prob_topk


-class SQuADMetric(MetricBase):
-    r"""
-    Alias: :class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric`
+class ExtractiveQAMetric(MetricBase):
+    """
+    Alias: :class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric`

-    Metric for the SQuAD dataset
+    Metric for extractive QA tasks such as SQuAD.

    :param pred1: the mapping for `pred1` in the parameter map; None means the mapping is `pred1` -> `pred1`
    :param pred2: the mapping for `pred2` in the parameter map; None means the mapping is `pred2` -> `pred2`
@@ -755,7 +756,7 @@ class SQuADMetric(MetricBase):
    def __init__(self, pred1=None, pred2=None, target1=None, target2=None,
                 beta=1, right_open=True, print_predict_stat=False):

-        super(SQuADMetric, self).__init__()
+        super(ExtractiveQAMetric, self).__init__()

        self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2)
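For downstream users the rename is mechanical. The following migration sketch is not part of the commit: it assumes an existing fastNLP `model` whose forward returns `pred1`/`pred2` span scores and a `dev_data` DataSet carrying `target1`/`target2` fields, per the docstring above.

```python
# Hedged sketch only: `model` and `dev_data` are assumed to exist elsewhere.
from fastNLP import Tester
from fastNLP.core.metrics import ExtractiveQAMetric  # formerly SQuADMetric

metric = ExtractiveQAMetric(beta=1, right_open=True, print_predict_stat=False)
tester = Tester(data=dev_data, model=model, metrics=metric)
tester.test()
```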
fastNLP/io/data_loader/__init__.py

@@ -4,16 +4,26 @@
These modules are used as follows:
"""
__all__ = [
-    'SSTLoader',
-
+    'IMDBLoader',
    'MatchingLoader',
-    'SNLILoader',
    'MNLILoader',
+    'MTL16Loader',
    'QNLILoader',
    'QuoraLoader',
    'RTELoader',
+    'SSTLoader',
+    'SNLILoader',
+    'YelpLoader',
]


+from .imdb import IMDBLoader
+from .matching import MatchingLoader
+from .mnli import MNLILoader
+from .mtl import MTL16Loader
+from .qnli import QNLILoader
+from .quora import QuoraLoader
+from .rte import RTELoader
+from .snli import SNLILoader
from .sst import SSTLoader
-from .matching import MatchingLoader, SNLILoader, \
-    MNLILoader, QNLILoader, QuoraLoader, RTELoader
+from .yelp import YelpLoader
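As a quick check of the new surface (a sketch, assuming fastNLP at this commit is installed), every task loader is now importable from one package:

```python
# The dataloaders previously reached through fastNLP.io.dataset_loader or
# reproduction/text_classification now live under fastNLP.io.data_loader.
from fastNLP.io.data_loader import (
    IMDBLoader, MatchingLoader, SNLILoader, MNLILoader, MTL16Loader,
    QNLILoader, QuoraLoader, RTELoader, SSTLoader, YelpLoader,
)
```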
fastNLP/io/data_loader/imdb.py (new file, 96 lines)

@@ -0,0 +1,96 @@

from typing import Union, Dict

from ..embed_loader import EmbeddingOption, EmbedLoader
from ..base_loader import DataSetLoader, DataInfo
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.const import Const

from ..utils import get_tokenizer


class IMDBLoader(DataSetLoader):
    """
    Reads the IMDB dataset. The resulting DataSet contains the following fields:

        words: list(str), the text to be classified
        target: str, the label of the text

    """

    def __init__(self):
        super(IMDBLoader, self).__init__()
        self.tokenizer = get_tokenizer()

    def _load(self, path):
        dataset = DataSet()
        with open(path, 'r', encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('\t')
                target = parts[0]
                words = self.tokenizer(parts[1].lower())
                dataset.append(Instance(words=words, target=target))

        if len(dataset) == 0:
            raise RuntimeError(f"{path} has no valid data.")

        return dataset

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                char_level_op=False):

        datasets = {}
        info = DataInfo()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        def wordtochar(words):
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')

        datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)

        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name='words')

        src_vocab.index_dataset(*datasets.values(), field_name='words')

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name='target')
        tgt_vocab.index_dataset(*datasets.values(), field_name='target')

        info.vocabs = {
            Const.INPUT: src_vocab,
            Const.TARGET: tgt_vocab
        }

        info.datasets = datasets

        for name, dataset in info.datasets.items():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info
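A hedged usage sketch for the new loader, not part of the commit: the file names are placeholders, and the input format assumed is the one `_load` parses (one `label<TAB>text` pair per line).

```python
# Hypothetical paths; process() takes a dict of split-name -> file.
from fastNLP.io.data_loader import IMDBLoader

info = IMDBLoader().process({'train': 'imdb_train.tsv', 'test': 'imdb_test.tsv'})
# process() carves a dev set off train (split(0.1), shuffle=False) and fits the
# word/target vocabularies on the training split only.
print(info.datasets.keys())        # train, test, dev
print(len(info.vocabs['words']))   # vocabs are keyed by Const.INPUT == 'words'
```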
fastNLP/io/data_loader/matching.py

@@ -5,14 +5,13 @@ from typing import Union, Dict
from ...core.const import Const
from ...core.vocabulary import Vocabulary
from ..base_loader import DataInfo, DataSetLoader
from ..dataset_loader import JsonLoader, CSVLoader
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from ...modules.encoder._bert import BertTokenizer


class MatchingLoader(DataSetLoader):
    """
-    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
+    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.data_loader.MatchingLoader`

    Reads datasets for matching tasks

@@ -227,204 +226,3 @@ class MatchingLoader(DataSetLoader):
            data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])

        return data_info


class SNLILoader(MatchingLoader, JsonLoader):
    """
    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`

    Reads the SNLI dataset. The resulting DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """

    def __init__(self, paths: dict=None):
        fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }
        paths = paths if paths is not None else {
            'train': 'snli_1.0_train.jsonl',
            'dev': 'snli_1.0_dev.jsonl',
            'test': 'snli_1.0_test.jsonl'}
        MatchingLoader.__init__(self, paths=paths)
        JsonLoader.__init__(self, fields=fields)

    def _load(self, path):
        ds = JsonLoader._load(self, path)

        parentheses_table = str.maketrans({'(': None, ')': None})

        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(1))
        ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds


class RTELoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`

    Reads the RTE dataset. The resulting DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no labels
        }
        MatchingLoader.__init__(self, paths=paths)
        self.fields = {
            'sentence1': Const.INPUTS(0),
            'sentence2': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if v in ds.get_field_names():
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds


class QNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`

    Reads the QNLI dataset. The resulting DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no labels
        }
        MatchingLoader.__init__(self, paths=paths)
        self.fields = {
            'question': Const.INPUTS(0),
            'sentence': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if v in ds.get_field_names():
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds


class MNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`

    Reads the MNLI dataset. The resulting DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev_matched': 'dev_matched.tsv',
            'dev_mismatched': 'dev_mismatched.tsv',
            'test_matched': 'test_matched.tsv',
            'test_mismatched': 'test_mismatched.tsv',
            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',

            # test_0.9_matched and test_0.9_mismatched belong to MNLI version 0.9 (data source: Kaggle)
        }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t')
        self.fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if k in ds.get_field_names():
                ds.rename_field(k, v)

        if Const.TARGET in ds.get_field_names():
            if ds[0][Const.TARGET] == 'hidden':
                ds.delete_field(Const.TARGET)

        parentheses_table = str.maketrans({'(': None, ')': None})

        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(1))
        if Const.TARGET in ds.get_field_names():
            ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds


class QuoraLoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`

    Reads the Quora dataset. The resulting DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv',
        }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))

    def _load(self, path):
        ds = CSVLoader._load(self, path)
        return ds
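The SNLI/MNLI `_load` implementations above recover plain tokens from the `*_binary_parse` fields by deleting parentheses and splitting on whitespace. A self-contained illustration of that preprocessing, independent of fastNLP:

```python
# How the binary-parse trick works on a toy parse string.
parse = "( ( The cat ) ( sat down ) )"
parentheses_table = str.maketrans({'(': None, ')': None})
tokens = parse.translate(parentheses_table).strip().split()
print(tokens)  # ['The', 'cat', 'sat', 'down']
```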
fastNLP/io/data_loader/mnli.py (new file, 60 lines)

@@ -0,0 +1,60 @@

from ...core import Const

from .matching import MatchingLoader
from ..dataset_loader import CSVLoader


class MNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.data_loader.MNLILoader`

    Reads the MNLI dataset. The resulting DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev_matched': 'dev_matched.tsv',
            'dev_mismatched': 'dev_mismatched.tsv',
            'test_matched': 'test_matched.tsv',
            'test_mismatched': 'test_mismatched.tsv',
            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',

            # test_0.9_matched and test_0.9_mismatched belong to MNLI version 0.9 (data source: Kaggle)
        }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t')
        self.fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if k in ds.get_field_names():
                ds.rename_field(k, v)

        if Const.TARGET in ds.get_field_names():
            if ds[0][Const.TARGET] == 'hidden':
                ds.delete_field(Const.TARGET)

        parentheses_table = str.maketrans({'(': None, ')': None})

        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(1))
        if Const.TARGET in ds.get_field_names():
            ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds
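`_load` also guards against GLUE's unlabeled MNLI test files, whose `gold_label` column holds the placeholder `hidden`. A self-contained sketch of that guard (field name as in the loader, data invented for illustration):

```python
from fastNLP import DataSet, Instance

ds = DataSet()
ds.append(Instance(words='...', target='hidden'))
# Same check as in _load: if the first gold label is 'hidden', the whole
# split is unlabeled, so the target field is dropped.
if 'target' in ds.get_field_names() and ds[0]['target'] == 'hidden':
    ds.delete_field('target')
print(ds.get_field_names())  # ['words']
```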
fastNLP/io/data_loader/mtl.py (new file, 65 lines)

@@ -0,0 +1,65 @@

from typing import Union, Dict

from ..base_loader import DataInfo
from ..dataset_loader import CSVLoader
from ...core.vocabulary import Vocabulary, VocabularyOption
from ...core.const import Const
from ..utils import check_dataloader_paths


class MTL16Loader(CSVLoader):
    """
    Reads the MTL16 dataset. The resulting DataSet contains the following fields:

        words: list(str), the text to be classified
        target: str, the label of the text

    Data source: https://pan.baidu.com/s/1c2L6vdA

    """

    def __init__(self):
        super(MTL16Loader, self).__init__(headers=(Const.TARGET, Const.INPUT), sep='\t')

    def _load(self, path):
        dataset = super(MTL16Loader, self)._load(path)
        dataset.apply(lambda x: x[Const.INPUT].lower().split(), new_field_name=Const.INPUT)
        if len(dataset) == 0:
            raise RuntimeError(f"{path} has no valid data.")

        return dataset

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,):

        paths = check_dataloader_paths(paths)
        datasets = {}
        info = DataInfo()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
        src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET)
        tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET)

        info.vocabs = {
            Const.INPUT: src_vocab,
            Const.TARGET: tgt_vocab
        }

        info.datasets = datasets

        for name, dataset in info.datasets.items():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info
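The vocabulary wiring in `process()` follows the usual fastNLP pattern: fit the vocabulary on the training split only, then index every split with it, so unseen dev/test tokens fall back to the unknown entry. A minimal self-contained version of that pattern:

```python
from fastNLP import DataSet, Instance, Vocabulary

train, test = DataSet(), DataSet()
train.append(Instance(words=['a', 'b'], target='pos'))
test.append(Instance(words=['b', 'c'], target='neg'))

vocab = Vocabulary()
vocab.from_dataset(train, field_name='words')          # fit on train only
vocab.index_dataset(train, test, field_name='words')   # 'c' maps to <unk>
print(train[0]['words'], test[0]['words'])
```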
fastNLP/io/data_loader/qnli.py (new file, 45 lines)

@@ -0,0 +1,45 @@

from ...core import Const

from .matching import MatchingLoader
from ..dataset_loader import CSVLoader


class QNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.data_loader.QNLILoader`

    Reads the QNLI dataset. The resulting DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no labels
        }
        MatchingLoader.__init__(self, paths=paths)
        self.fields = {
            'question': Const.INPUTS(0),
            'sentence': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if k in ds.get_field_names():
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds
fastNLP/io/data_loader/quora.py (new file, 32 lines)

@@ -0,0 +1,32 @@

from ...core import Const

from .matching import MatchingLoader
from ..dataset_loader import CSVLoader


class QuoraLoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.data_loader.QuoraLoader`

    Reads the Quora dataset. The resulting DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv',
        }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))

    def _load(self, path):
        ds = CSVLoader._load(self, path)
        return ds
fastNLP/io/data_loader/rte.py (new file, 45 lines)

@@ -0,0 +1,45 @@

from ...core import Const

from .matching import MatchingLoader
from ..dataset_loader import CSVLoader


class RTELoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.data_loader.RTELoader`

    Reads the RTE dataset. The resulting DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no labels
        }
        MatchingLoader.__init__(self, paths=paths)
        self.fields = {
            'sentence1': Const.INPUTS(0),
            'sentence2': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if k in ds.get_field_names():
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds
fastNLP/io/data_loader/snli.py (new file, 44 lines)

@@ -0,0 +1,44 @@

from ...core import Const

from .matching import MatchingLoader
from ..dataset_loader import JsonLoader


class SNLILoader(MatchingLoader, JsonLoader):
    """
    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.data_loader.SNLILoader`

    Reads the SNLI dataset. The resulting DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """

    def __init__(self, paths: dict=None):
        fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }
        paths = paths if paths is not None else {
            'train': 'snli_1.0_train.jsonl',
            'dev': 'snli_1.0_dev.jsonl',
            'test': 'snli_1.0_test.jsonl'}
        MatchingLoader.__init__(self, paths=paths)
        JsonLoader.__init__(self, fields=fields)

    def _load(self, path):
        ds = JsonLoader._load(self, path)

        parentheses_table = str.maketrans({'(': None, ')': None})

        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(1))
        ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds
fastNLP/io/data_loader/sst.py

@@ -1,19 +1,19 @@
-from typing import Iterable
+from typing import Union, Dict
from nltk import Tree
import spacy

from ..base_loader import DataInfo, DataSetLoader
+from ..dataset_loader import CSVLoader
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
+from ...core.const import Const
from ...core.instance import Instance
+from ..utils import check_dataloader_paths, get_tokenizer


class SSTLoader(DataSetLoader):
-    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
-    DATA_DIR = 'sst/'
-
    """
-    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
+    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.data_loader.SSTLoader`

    Reads the SST dataset. The resulting DataSet contains the fields::

@@ -26,6 +26,9 @@ class SSTLoader(DataSetLoader):
    :param fine_grained: whether to use the SST-5 standard; if ``False``, SST-2 is used. Default: ``False``
    """

+    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
+    DATA_DIR = 'sst/'
+
    def __init__(self, subtree=False, fine_grained=False):
        self.subtree = subtree

@@ -98,3 +101,72 @@ class SSTLoader(DataSetLoader):

        return info


class SST2Loader(CSVLoader):
    """
    Data source "SST": https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8
    """

    def __init__(self):
        super(SST2Loader, self).__init__(sep='\t')
        self.tokenizer = get_tokenizer()
        self.field = {'sentence': Const.INPUT, 'label': Const.TARGET}

    def _load(self, path: str) -> DataSet:
        ds = super(SST2Loader, self)._load(path)
        ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT)
        print("all count:", len(ds))
        return ds

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                char_level_op=False):

        paths = check_dataloader_paths(paths)
        datasets = {}
        info = DataInfo()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        def wordtochar(words):
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

        input_name, target_name = Const.INPUT, Const.TARGET
        info.vocabs = {}

        # split words into characters
        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
        src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET)
        tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET)

        info.vocabs = {
            Const.INPUT: src_vocab,
            Const.TARGET: tgt_vocab
        }

        info.datasets = datasets

        for name, dataset in info.datasets.items():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info
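The `wordtochar` helper (defined identically here and in the IMDB/Yelp loaders) flattens a word list into a character sequence, using an empty string as a word-boundary marker. Extracted and run standalone:

```python
def wordtochar(words):
    chars = []
    for word in words:
        word = word.lower()
        for char in word:
            chars.append(char)
        chars.append('')   # '' separates consecutive words
    chars.pop()            # drop the trailing separator
    return chars

print(wordtochar(['Good', 'film']))
# ['g', 'o', 'o', 'd', '', 'f', 'i', 'l', 'm']
```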
fastNLP/io/data_loader/yelp.py (new file, 126 lines)

@@ -0,0 +1,126 @@

import csv
from typing import Iterable

from ...core.const import Const
from ...core import DataSet, Instance, Vocabulary
from ...core.vocabulary import VocabularyOption
from ..base_loader import DataInfo, DataSetLoader
from typing import Union, Dict
from ..utils import check_dataloader_paths, get_tokenizer


class YelpLoader(DataSetLoader):
    """
    Reads the Yelp_full/Yelp_polarity dataset. The resulting DataSet contains the fields:

        words: list(str), the text to be classified
        target: str, the label of the text
        chars: list(str), the un-indexed list of characters

    Dataset: yelp_full/yelp_polarity
    :param fine_grained: whether to use the five-class (1-5 star) labels; if ``False``, the 1/2-star and 4/5-star labels are merged into negative/positive. Default: ``False``
    :param lower: whether to lowercase the text automatically. Default: False.
    """

    def __init__(self, fine_grained=False, lower=False):
        super(YelpLoader, self).__init__()
        tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
                 '4.0': 'positive', '5.0': 'very positive'}
        if not fine_grained:
            tag_v['1.0'] = tag_v['2.0']
            tag_v['5.0'] = tag_v['4.0']
        self.fine_grained = fine_grained
        self.tag_v = tag_v
        self.lower = lower
        self.tokenizer = get_tokenizer()

    def _load(self, path):
        ds = DataSet()
        csv_reader = csv.reader(open(path, encoding='utf-8'))
        all_count = 0
        real_count = 0
        for row in csv_reader:
            all_count += 1
            if len(row) == 2:
                target = self.tag_v[row[0] + ".0"]
                words = clean_str(row[1], self.tokenizer, self.lower)
                if len(words) != 0:
                    ds.append(Instance(words=words, target=target))
                    real_count += 1
        print("all count:", all_count)
        print("real count:", real_count)
        return ds

    def process(self, paths: Union[str, Dict[str, str]],
                train_ds: Iterable[str] = None,
                src_vocab_op: VocabularyOption = None,
                tgt_vocab_op: VocabularyOption = None,
                char_level_op=False):
        paths = check_dataloader_paths(paths)
        info = DataInfo(datasets=self.load(paths))
        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
        _train_ds = [info.datasets[name]
                     for name in train_ds] if train_ds else info.datasets.values()

        def wordtochar(words):
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

        input_name, target_name = Const.INPUT, Const.TARGET
        info.vocabs = {}
        # split words into characters
        if char_level_op:
            for dataset in info.datasets.values():
                dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
        else:
            src_vocab.from_dataset(*_train_ds, field_name=input_name)
            src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name)
            info.vocabs[input_name] = src_vocab

        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
        tgt_vocab.index_dataset(
            *info.datasets.values(),
            field_name=target_name, new_field_name=target_name)

        info.vocabs[target_name] = tgt_vocab

        info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False)

        for name, dataset in info.datasets.items():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info


def clean_str(sentence, tokenizer, char_lower=False):
    """
    heavily borrowed from github
    https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
    :param sentence: is a str
    :return:
    """
    if char_lower:
        sentence = sentence.lower()
    import re
    nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
    words = tokenizer(sentence)
    words_collection = []
    for word in words:
        if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']:
            continue
        tt = nonalpnum.split(word)
        t = ''.join(tt)
        if t != '':
            words_collection.append(t)

    return words_collection
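A quick check of `clean_str`'s behaviour, assuming the module path this commit creates and using a plain whitespace tokenizer in place of `get_tokenizer()`:

```python
from fastNLP.io.data_loader.yelp import clean_str

# Parse artifacts such as '-lrb-'/'-rrb-' are skipped, and characters outside
# [0-9a-zA-Z?!'] are stripped from each remaining token.
tokens = clean_str("-lrb- It's great!! -rrb-", tokenizer=str.split)
print(tokens)  # ["It's", 'great!!']
```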
fastNLP/modules/aggregator/__init__.py (deleted)

@@ -1,14 +0,0 @@
-__all__ = [
-    "MaxPool",
-    "MaxPoolWithMask",
-    "AvgPool",
-
-    "MultiHeadAttention",
-]
-
-from .pooling import MaxPool
-from .pooling import MaxPoolWithMask
-from .pooling import AvgPool
-from .pooling import AvgPoolWithMask
-
-from .attention import MultiHeadAttention
fastNLP/modules/encoder/attention.py

@@ -8,9 +8,9 @@ import torch
import torch.nn.functional as F
from torch import nn

-from ..dropout import TimestepDropout
+from fastNLP.modules.dropout import TimestepDropout

-from ..utils import initial_parameter
+from fastNLP.modules.utils import initial_parameter


class DotAttention(nn.Module):
fastNLP/modules/encoder/transformer.py

@@ -3,7 +3,7 @@ __all__ = [
]
from torch import nn

-from ..aggregator.attention import MultiHeadAttention
+from fastNLP.modules.encoder.attention import MultiHeadAttention
from ..dropout import TimestepDropout

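With `modules.aggregator` folded into `modules.encoder`, the paths shown in these hunks become the canonical ones. A smoke test of the new locations (assumes fastNLP at this commit):

```python
# Both names now resolve under fastNLP.modules.encoder / fastNLP.modules,
# matching the absolute imports introduced above.
from fastNLP.modules.encoder.attention import MultiHeadAttention
from fastNLP.modules.dropout import TimestepDropout

print(MultiHeadAttention.__module__, TimestepDropout.__module__)
```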
reproduction/coreference_resolution/README.md

@@ -11,7 +11,7 @@ Coreference resolution is the task of finding all expressions in a text that refer to the same real-world entity
Due to copyright issues, the dataset cannot be distributed here; please download it yourself.
The raw dataset is in CoNLL format; see the official description page provided with the dataset for details.

-The implementation follows the preprocessing of the paper's author, Lee; for details refer to the [link](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh).
+The implementation follows the preprocessing of the paper's author, Lee; for details see the [link](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh).
The processed dataset is in json format; an example:
```
{
@@ -25,12 +25,12 @@
### Embedding downloads
[turian embedding](https://lil.cs.washington.edu/coref/turian.50d.txt)

-[glove embedding]( https://nlp.stanford.edu/data/glove.840B.300d.zip)
+[glove embedding](https://nlp.stanford.edu/data/glove.840B.300d.zip)


## Running
-```python
+```shell
# training script
CUDA_VISIBLE_DEVICES=0 python train.py
# test script
@@ -39,9 +39,9 @@ CUDA_VISIBLE_DEVICES=0 python valid.py

## Results
The original authors report 67.2% on the test set; the AllenNLP reproduction reports [63.0%](https://allennlp.org/models).
-Note that allenNLP's training run omitted speaker information and variational dropout, and used only 100 antecedents instead of 250.
+Note that AllenNLP's training run omitted speaker information and variational dropout, and used only 100 antecedents instead of 250.

-With the same hyperparameters and configuration as allenNLP, this reproduction reaches an F1 of 63.6%.
+With the same hyperparameters and configuration as AllenNLP, this reproduction reaches an F1 of 63.6%.


## Issues
reproduction/matching/README.md

@@ -2,7 +2,7 @@
Several well-known matching-task models are reproduced here with fastNLP, aiming to match the performance reported in their papers. The evaluation metric for all of these tasks is accuracy (%).

The reproduced models are (ordered by paper publication date):
-- CNTN: model code (still in progress) [](); training code (still in progress) []().
+- CNTN: [model code](model/cntn.py); [training code](matching_cntn.py).
  Paper: [Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844).
- ESIM: [model code](model/esim.py); [training code](matching_esim.py).
  Paper: [Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf).
@@ -21,7 +21,7 @@

model name | SNLI | MNLI | RTE | QNLI | Quora
:---: | :---: | :---: | :---: | :---: | :---:
-CNTN [](); [paper](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - |
+CNTN [code](model/cntn.py); [paper](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 77.79 vs - | 63.29/63.16(dev) vs - | 57.04(dev) vs - | 62.38(dev) vs - | - |
ESIM [code](model/esim.py); [paper](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - |
DIIN [](); [paper](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 |
MwAN [](); [paper](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 |
@@ -44,7 +44,7 @@ Performance on Test set:

model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large
:---: | :---: | :---: | :---: | :---: | :---: | :---:
-__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16
+__performance__ | 77.79 | 88.13 | - | 87.9 | 90.6 | 91.16

## MNLI
[Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/)
@@ -60,7 +60,7 @@ Performance on Test set (matched/mismatched):

model name | CNTN | ESIM | DIIN | MwAN | BERT-Base
:---: | :---: | :---: | :---: | :---: | :---: |
-__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - |
+__performance__ | 63.29/63.16(dev) | 77.78/76.49 | - | 77.3/76.7(dev) | - |


## RTE
@@ -92,7 +92,7 @@ Performance on __Dev__ set:

model name | CNTN | ESIM | DIIN | MwAN | BERT
:---: | :---: | :---: | :---: | :---: | :---:
-__performance__ | - | 76.97 | - | 74.6 | -
+__performance__ | 62.38 | 76.97 | - | 74.6 | -

## Quora
requirements.txt

@@ -3,3 +3,5 @@ torch>=1.0.0
tqdm>=4.28.1
nltk>=3.4.1
requests
+spacy
+h5py