Add a Chinese classification Pipe; use matrix operations to speed up some of BertEmbedding's pool_method options; rename some test cases; fix a spurious warning in the metrics

This commit is contained in:
yh_cc 2019-09-04 12:47:52 +08:00
parent d15ad75d96
commit e903db0e70
20 changed files with 274 additions and 60 deletions

View File

@ -238,8 +238,8 @@ class CrossEntropyLoss(LossBase):
pred = pred.tranpose(-1, pred)
pred = pred.reshape(-1, pred.size(-1))
target = target.reshape(-1)
if seq_len is not None:
mask = seq_len_to_mask(seq_len).reshape(-1).eq(0)
if seq_len is not None and target.dim()>1:
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0)
target = target.masked_fill(mask, self.padding_idx)
return F.cross_entropy(input=pred, target=target,
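
The extra target.dim() > 1 check means the pad mask is only built for token-level (2-D) targets, where seq_len actually applies. Below is a rough standalone sketch of the same masking idea in plain PyTorch; the helper name is illustrative and not fastNLP's seq_len_to_mask.

import torch
import torch.nn.functional as F

def masked_token_cross_entropy(pred, target, seq_len, padding_idx=-100):
    # pred: (batch, max_len, n_classes), target: (batch, max_len), seq_len: (batch,)
    max_len = target.size(1)
    # equivalent of seq_len_to_mask: True for real tokens, False for padding
    valid = torch.arange(max_len, device=seq_len.device)[None, :] < seq_len[:, None]
    pred = pred.reshape(-1, pred.size(-1))
    target = target.reshape(-1).masked_fill(valid.reshape(-1).eq(0), padding_idx)
    return F.cross_entropy(pred, target, ignore_index=padding_idx)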

View File

@ -347,7 +347,7 @@ class AccuracyMetric(MetricBase):
pass
elif pred.dim() == target.dim() + 1:
pred = pred.argmax(dim=-1)
if seq_len is None:
if seq_len is None and target.dim()>1:
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
else:
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "

View File

@ -68,7 +68,7 @@ class BertEmbedding(ContextualEmbedding):
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1',
pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False,
pooled_cls=True, requires_grad: bool = False, auto_truncate: bool = False):
pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False):
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
@ -165,7 +165,7 @@ class BertWordPieceEncoder(nn.Module):
"""
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
word_dropout=0, dropout=0, requires_grad: bool = False):
word_dropout=0, dropout=0, requires_grad: bool = True):
super().__init__()
self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
@ -288,7 +288,7 @@ class _WordBertModel(nn.Module):
self.auto_truncate = auto_truncate
# Compute the word pieces for every word in the vocab; [CLS] and [SEP] need to be handled separately
logger.info("Start to generating word pieces for word.")
logger.info("Start to generate word pieces for word.")
# Step 1: collect the required word pieces, then build the new embed and word_piece_vocab and fill in their values
word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # word pieces that are used, plus newly added ones
found_count = 0
@ -374,7 +374,8 @@ class _WordBertModel(nn.Module):
else:
raise RuntimeError(
"After split words into word pieces, the lengths of word pieces are longer than the "
f"maximum allowed sequence length:{self._max_position_embeddings} of bert.")
f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set "
f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")
# +2 because [CLS] and [SEP] have to be added
word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)),
@ -407,15 +408,26 @@ class _WordBertModel(nn.Module):
# output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size
if self.include_cls_sep:
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
bert_outputs[-1].size(-1))
s_shift = 1
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
bert_outputs[-1].size(-1))
else:
s_shift = 0
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len,
bert_outputs[-1].size(-1))
s_shift = 0
batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1)
batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len
if self.pool_method == 'first':
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
elif self.pool_method == 'last':
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
for l_index, l in enumerate(self.layers):
output_layer = bert_outputs[l]
real_word_piece_length = output_layer.size(1) - 2
@ -426,16 +438,15 @@ class _WordBertModel(nn.Module):
output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
# collapse word-piece representations into word representations
truncate_output_layer = output_layer[:, 1:-1] # drop [CLS] and [SEP]; batch_size x len x hidden_size
outputs_seq_len = seq_len + s_shift
if self.pool_method == 'first':
for i in range(batch_size):
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, :seq_len[i]] # start position of each word
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[
i, i_word_pieces_cum_length] # num_layer x batch_size x len x hidden_size
tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length]
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
elif self.pool_method == 'last':
for i in range(batch_size):
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, 1:seq_len[i] + 1] - 1 # end position of each word
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length]
tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length]
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
elif self.pool_method == 'max':
for i in range(batch_size):
for j in range(seq_len[i]):
@ -452,5 +463,6 @@ class _WordBertModel(nn.Module):
else:
outputs[l_index, :, 0] = output_layer[:, 0]
outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift]
# 3. final embedding result
return outputs
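
The per-sample Python loops for pool_method 'first' and 'last' are replaced above by a single batched gather: the cumulative word-piece lengths give, for every word, the index of its first (or last) word piece, and advanced indexing pulls all of them out in one shot. A small standalone sketch of that gather with toy shapes follows; the names mirror the diff, the values are made up, and the masking of padded words is omitted.

import torch

# toy sizes: 2 sentences, at most 3 words, hidden size 4
batch_size, max_word_len, hidden = 2, 3, 4
truncate_output_layer = torch.randn(batch_size, 7, hidden)       # word-piece reps, [CLS]/[SEP] removed
batch_word_pieces_length = torch.tensor([[1, 2, 1], [2, 2, 0]])  # word pieces per word (0 = padding)

cum = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1)
cum[:, 1:] = batch_word_pieces_length.cumsum(dim=-1)
first_index = cum[:, :max_word_len]                              # index of each word's first word piece
batch_indexes = torch.arange(batch_size)[:, None].expand(batch_size, max_word_len)
first_pooled = truncate_output_layer[batch_indexes, first_index] # (batch_size, max_word_len, hidden)
last_pooled = truncate_output_layer[batch_indexes, (cum[:, 1:max_word_len + 1] - 1).clamp(min=0)]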

View File

@ -24,6 +24,7 @@ __all__ = [
'IMDBLoader',
'SSTLoader',
'SST2Loader',
"ChnSentiCorpLoader",
'ConllLoader',
'Conll2003Loader',
@ -52,8 +53,9 @@ __all__ = [
"SSTPipe",
"SST2Pipe",
"IMDBPipe",
"Conll2003Pipe",
"ChnSentiCorpPipe",
"Conll2003Pipe",
"Conll2003NERPipe",
"OntoNotesNERPipe",
"MsraNERPipe",

View File

@ -306,12 +306,15 @@ class DataBundle:
return self
def __repr__(self):
_str = 'In total {} datasets:\n'.format(len(self.datasets))
for name, dataset in self.datasets.items():
_str += '\t{} has {} instances.\n'.format(name, len(dataset))
_str += 'In total {} vocabs:\n'.format(len(self.vocabs))
for name, vocab in self.vocabs.items():
_str += '\t{} has {} entries.\n'.format(name, len(vocab))
_str = ''
if len(self.datasets):
_str += 'In total {} datasets:\n'.format(len(self.datasets))
for name, dataset in self.datasets.items():
_str += '\t{} has {} instances.\n'.format(name, len(dataset))
if len(self.vocabs):
_str += 'In total {} vocabs:\n'.format(len(self.vocabs))
for name, vocab in self.vocabs.items():
_str += '\t{} has {} entries.\n'.format(name, len(vocab))
return _str
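
With empty sections now skipped, printing a bundle that only holds datasets no longer emits a vocab header. A rough usage sketch, assuming the DataBundle constructor of this version accepts datasets/vocabs dicts; the field values are illustrative only.

from fastNLP import DataSet
from fastNLP.io import DataBundle

bundle = DataBundle(datasets={'train': DataSet({'x': [1, 2, 3]}), 'dev': DataSet({'x': [4]})})
print(bundle)
# In total 2 datasets:
#     train has 3 instances.
#     dev has 1 instances.
# empty sections (here: vocabs) are skipped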

View File

@ -77,6 +77,9 @@ PRETRAIN_STATIC_FILES = {
'cn-tencent': "tencent_cn.zip",
'cn-fasttext': "cc.zh.300.vec.gz",
'cn-sgns-literature-word': 'sgns.literature.word.txt.zip',
'cn-char-fastnlp-100d': "cn_char_fastnlp_100d.zip",
'cn-bi-fastnlp-100d': "cn_bi_fastnlp_100d.zip",
"cn-tri-fastnlp-100d": "cn_tri_fastnlp_100d.zip"
}
DATASET_DIR = {
@ -96,7 +99,9 @@ DATASET_DIR = {
"cws-pku": 'cws_pku.zip',
"cws-cityu": "cws_cityu.zip",
"cws-as": 'cws_as.zip',
"cws-msra": 'cws_msra.zip'
"cws-msra": 'cws_msra.zip',
"chn-senti-corp":"chn_senti_corp.zip"
}
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,

View File

@ -52,6 +52,7 @@ __all__ = [
'IMDBLoader',
'SSTLoader',
'SST2Loader',
"ChnSentiCorpLoader",
'ConllLoader',
'Conll2003Loader',
@ -73,7 +74,7 @@ __all__ = [
"QNLILoader",
"RTELoader"
]
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
from .csv import CSVLoader
from .cws import CWSLoader

View File

@ -7,6 +7,7 @@ __all__ = [
"IMDBLoader",
"SSTLoader",
"SST2Loader",
"ChnSentiCorpLoader"
]
import glob
@ -346,3 +347,59 @@ class SST2Loader(Loader):
"""
output_dir = self._get_dataset_path(dataset_name='sst-2')
return output_dir
class ChnSentiCorpLoader(Loader):
"""
The expected data format: the first line is a header (its content is ignored); each following line is one sample, where everything before the first tab
is treated as the label and everything after the first tab as the sentence.
Example::
label raw_chars
1 這間酒店環境和服務態度亦算不錯,但房間空間太小~~
1 <荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道...
0 商品的不足暂时还没发现京东的订单处理速度实在.......周二就打包完成周五才发货...
The loaded DataSet has the following fields:
.. csv-table::
:header: "raw_chars", "target"
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1"
"<荐书> 推荐所有喜欢<红楼>...", "1"
"..."
"""
def __init__(self):
super().__init__()
def _load(self, path:str):
"""
Load the data from path.
:param path:
:return:
"""
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline()
for line in f:
line = line.strip()
tab_index = line.find('\t')
if tab_index != -1:
target = line[:tab_index]
raw_chars = line[tab_index+1:]
if raw_chars:
ds.append(Instance(raw_chars=raw_chars, target=target))
return ds
def download(self)->str:
"""
Automatically download the data. The data is taken from https://github.com/pengming617/bert_classification/tree/master/data
and is used in https://arxiv.org/pdf/1904.09223.pdf and https://arxiv.org/pdf/1906.08101.pdf.
:return:
"""
output_dir = self._get_dataset_path('chn-senti-corp')
return output_dir
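
A brief usage sketch for the new loader; the field names follow the docstring above, and the download/caching behaviour is assumed to match the other loaders.

from fastNLP.io.loader.classification import ChnSentiCorpLoader

loader = ChnSentiCorpLoader()
data_dir = loader.download()         # fetch and cache the chn-senti-corp archive
data_bundle = loader.load(data_dir)  # DataSets with raw_chars and target fields
print(data_bundle)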

View File

@ -17,6 +17,7 @@ __all__ = [
"SSTPipe",
"SST2Pipe",
"IMDBPipe",
"ChnSentiCorpPipe",
"Conll2003NERPipe",
"OntoNotesNERPipe",
@ -39,7 +40,7 @@ __all__ = [
"MNLIPipe",
]
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe
from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe

View File

@ -5,7 +5,8 @@ __all__ = [
"YelpPolarityPipe",
"SSTPipe",
"SST2Pipe",
'IMDBPipe'
'IMDBPipe',
"ChnSentiCorpPipe"
]
import re
@ -13,18 +14,18 @@ import re
from nltk import Tree
from .pipe import Pipe
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance, _add_chars_field
from ..data_bundle import DataBundle
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.vocabulary import Vocabulary
from ..loader.classification import ChnSentiCorpLoader
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
class _CLSPipe(Pipe):
"""
Base class for classification pipes; handles tokenization of classification data. By default it operates on the raw_words column and produces the words column.
@ -457,3 +458,97 @@ class IMDBPipe(_CLSPipe):
data_bundle = self.process(data_bundle)
return data_bundle
class ChnSentiCorpPipe(Pipe):
"""
After processing, the DataSet has the following structure:
.. csv-table::
:header: "raw_chars", "chars", "target", "seq_len"
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "[2, 3, 4, 5, ...]", 1, 31
"<荐书> 推荐所有喜欢<红楼>...", "[10, 21, ....]", 1, 25
"..."
chars and seq_len are set as input, and target is set as target.
:param bool bigrams: whether to add a bigrams column. Bigrams are built as ['复', '旦', '大', '学', ...] -> ["复旦", "旦大", ...]. If set to
True, the returned DataSet will have a column named bigrams, already converted to indices and set as input; the corresponding vocab can be
obtained via data_bundle.get_vocab('bigrams').
:param bool trigrams: whether to add a trigrams column. Trigrams are built as ['复', '旦', '大', '学', ...] -> ["复旦大", "旦大学", ...].
If set to True, the returned DataSet will have a column named trigrams, already converted to indices and set as input; the corresponding vocab
can be obtained via data_bundle.get_vocab('trigrams').
"""
def __init__(self, bigrams=False, trigrams=False):
super().__init__()
self.bigrams = bigrams
self.trigrams = trigrams
def _tokenize(self, data_bundle):
"""
Split "复旦大学" in the DataSet into ["复", "旦", "大", "学"]. In the future, word segmentation could be supported by extending this function.
:param data_bundle:
:return:
"""
data_bundle.apply_field(list, field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT)
return data_bundle
def process(self, data_bundle:DataBundle):
"""
The DataSet to be processed should have the following fields:
.. csv-table::
:header: "raw_chars", "target"
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1"
"<荐书> 推荐所有喜欢<红楼>...", "1"
"..."
:param data_bundle:
:return:
"""
_add_chars_field(data_bundle, lower=False)
data_bundle = self._tokenize(data_bundle)
input_field_names = [Const.CHAR_INPUT]
if self.bigrams:
for name, dataset in data_bundle.iter_datasets():
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
field_name=Const.CHAR_INPUT, new_field_name='bigrams')
input_field_names.append('bigrams')
if self.trigrams:
for name, dataset in data_bundle.iter_datasets():
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
field_name=Const.CHAR_INPUT, new_field_name='trigrams')
input_field_names.append('trigrams')
# index
_indexize(data_bundle, input_field_names, Const.TARGET)
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
target_fields = [Const.TARGET]
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.CHAR_INPUT)
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
return data_bundle
def process_from_file(self, paths=None):
"""
:param paths: supported path types are described in the load function of :class:`fastNLP.io.loader.Loader`.
:return: DataBundle
"""
# load the data
data_bundle = ChnSentiCorpLoader().load(paths)
data_bundle = self.process(data_bundle)
return data_bundle
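
The bigrams/trigrams columns are produced by the zip construction used in process() above; in isolation it behaves as follows (the input string is only an example).

chars = list("复旦大学")                        # ['复', '旦', '大', '学']
bigrams = [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
trigrams = [c1 + c2 + c3 for c1, c2, c3 in
            zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
print(bigrams)   # ['复旦', '旦大', '大学', '学<eos>']
print(trigrams)  # ['复旦大', '旦大学', '大学<eos>', '学<eos><eos>']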

View File

@ -222,14 +222,23 @@ class _CNNERPipe(Pipe):
target converted to indices. In the returned DataSet, chars, target and seq_len are set as input; target and seq_len are set as target.
:param str encoding_type: the encoding scheme used for the target column; bioes and bio are supported.
:param bool bigrams: whether to add a bigrams column. Bigrams are built as ['复', '旦', '大', '学', ...] -> ["复旦", "旦大", ...]. If set to
True, the returned DataSet will have a column named bigrams, already converted to indices and set as input; the corresponding vocab can be
obtained via data_bundle.get_vocab('bigrams').
:param bool trigrams: whether to add a trigrams column. Trigrams are built as ['复', '旦', '大', '学', ...] -> ["复旦大", "旦大学", ...].
If set to True, the returned DataSet will have a column named trigrams, already converted to indices and set as input; the corresponding vocab
can be obtained via data_bundle.get_vocab('trigrams').
"""
def __init__(self, encoding_type: str = 'bio'):
def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False):
if encoding_type == 'bio':
self.convert_tag = iob2
else:
self.convert_tag = lambda words: iob2bioes(iob2(words))
self.bigrams = bigrams
self.trigrams = trigrams
def process(self, data_bundle: DataBundle) -> DataBundle:
"""
The supported DataSet fields are:
@ -241,11 +250,11 @@ class _CNNERPipe(Pipe):
"[青, 岛, 海, 牛, 队, 和, ...]", "[B-ORG, I-ORG, I-ORG, ...]"
"[...]", "[...]"
raw_chars is a List[str] with the raw, unconverted data; chars is a List[int] with the input converted to indices; target is a List[int] with the
target converted to indices. In the returned DataSet, chars, target and seq_len are set as input; target is set as target.
raw_chars is a List[str] with the raw, unconverted data; chars is a List[int] with the input converted to indices; target is a List[int]
with the target converted to indices. In the returned DataSet, chars, target and seq_len are set as input; target is set as target.
:param ~fastNLP.DataBundle data_bundle: the DataSets in the passed-in DataBundle must contain the two fields raw_words and ner, and both fields must hold List[str].
The passed-in DataBundle is modified in place.
:param ~fastNLP.DataBundle data_bundle: the DataSets in the passed-in DataBundle must contain the two fields raw_words and ner, and both
fields must hold List[str]. The passed-in DataBundle is modified in place.
:return: DataBundle
"""
# convert tags
@ -253,11 +262,24 @@ class _CNNERPipe(Pipe):
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)
_add_chars_field(data_bundle, lower=False)
input_field_names = [Const.CHAR_INPUT]
if self.bigrams:
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
field_name=Const.CHAR_INPUT, new_field_name='bigrams')
input_field_names.append('bigrams')
if self.trigrams:
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
field_name=Const.CHAR_INPUT, new_field_name='trigrams')
input_field_names.append('trigrams')
# index
_indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET)
_indexize(data_bundle, input_field_names, Const.TARGET)
input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN]
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
target_fields = [Const.TARGET, Const.INPUT_LEN]
for name, dataset in data_bundle.datasets.items():
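
Any Chinese NER pipe built on _CNNERPipe now accepts the same flags; a brief usage sketch follows, relying on automatic download as the tests later in this commit do.

from fastNLP.io import MsraNERPipe

data_bundle = MsraNERPipe(encoding_type='bioes', bigrams=True, trigrams=True).process_from_file()
print(data_bundle.get_vocab('bigrams'))  # vocabulary built for the added bigrams column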

View File

@ -13,6 +13,12 @@ class TestDownload(unittest.TestCase):
words = torch.LongTensor([[2, 3, 4, 0]])
print(embed(words).size())
for pool_method in ['first', 'last', 'max', 'avg']:
for include_cls_sep in [True, False]:
embed = BertEmbedding(vocab, model_dir_or_name='en', pool_method=pool_method,
include_cls_sep=include_cls_sep)
print(embed(words).size())
def test_word_drop(self):
vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)

View File

@ -5,22 +5,22 @@ from fastNLP.io.loader.classification import YelpPolarityLoader
from fastNLP.io.loader.classification import IMDBLoader
from fastNLP.io.loader.classification import SST2Loader
from fastNLP.io.loader.classification import SSTLoader
from fastNLP.io.loader.classification import ChnSentiCorpLoader
import os
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestDownload(unittest.TestCase):
def test_download(self):
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]:
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
loader().download()
def test_load(self):
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]:
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
data_bundle = loader().load()
print(data_bundle)
class TestLoad(unittest.TestCase):
def test_load(self):
for loader in [IMDBLoader]:
data_bundle = loader().load('test/data_for_tests/io/imdb')

View File

@ -5,7 +5,7 @@ from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNE
Conll2003Loader
class MSRANERTest(unittest.TestCase):
class TestMSRANER(unittest.TestCase):
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_download(self):
MsraNERLoader().download(re_download=False)
@ -13,13 +13,13 @@ class MSRANERTest(unittest.TestCase):
print(data_bundle)
class PeopleDailyTest(unittest.TestCase):
class TestPeopleDaily(unittest.TestCase):
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_download(self):
PeopleDailyNERLoader().download()
class WeiboNERTest(unittest.TestCase):
class TestWeiboNER(unittest.TestCase):
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_download(self):
WeiboNERLoader().download()

View File

@ -3,7 +3,7 @@ import os
from fastNLP.io.loader import CWSLoader
class CWSLoaderTest(unittest.TestCase):
class TestCWSLoader(unittest.TestCase):
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_download(self):
dataset_names = ['pku', 'cityu', 'as', 'msra']
@ -13,7 +13,7 @@ class CWSLoaderTest(unittest.TestCase):
print(data_bundle)
class RunCWSLoaderTest(unittest.TestCase):
class TestRunCWSLoader(unittest.TestCase):
def test_cws_loader(self):
dataset_names = ['msra']
for dataset_name in dataset_names:

View File

@ -8,7 +8,7 @@ from fastNLP.io.loader.matching import MNLILoader
import os
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestDownload(unittest.TestCase):
class TestMatchingDownload(unittest.TestCase):
def test_download(self):
for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]:
loader().download()
@ -21,8 +21,7 @@ class TestDownload(unittest.TestCase):
print(data_bundle)
class TestLoad(unittest.TestCase):
class TestMatchingLoad(unittest.TestCase):
def test_load(self):
for loader in [RTELoader]:
data_bundle = loader().load('test/data_for_tests/io/rte')

View File

@ -2,9 +2,10 @@ import unittest
import os
from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe
from fastNLP.io.pipe.classification import ChnSentiCorpPipe
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestPipe(unittest.TestCase):
class TestClassificationPipe(unittest.TestCase):
def test_process_from_file(self):
for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
with self.subTest(pipe=pipe):
@ -14,8 +15,16 @@ class TestPipe(unittest.TestCase):
class TestRunPipe(unittest.TestCase):
def test_load(self):
for pipe in [IMDBPipe]:
data_bundle = pipe(tokenizer='raw').process_from_file('test/data_for_tests/io/imdb')
print(data_bundle)
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestCNClassificationPipe(unittest.TestCase):
def test_process_from_file(self):
for pipe in [ChnSentiCorpPipe]:
with self.subTest(pipe=pipe):
data_bundle = pipe(bigrams=True, trigrams=True).process_from_file()
print(data_bundle)

View File

@ -4,12 +4,14 @@ from fastNLP.io import MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, Conll2003Pipe
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestPipe(unittest.TestCase):
class TestConllPipe(unittest.TestCase):
def test_process_from_file(self):
for pipe in [MsraNERPipe, PeopleDailyPipe, WeiboNERPipe]:
with self.subTest(pipe=pipe):
print(pipe)
data_bundle = pipe().process_from_file()
data_bundle = pipe(bigrams=True, trigrams=True).process_from_file()
print(data_bundle)
data_bundle = pipe(encoding_type='bioes').process_from_file()
print(data_bundle)

View File

@ -4,7 +4,7 @@ import os
from fastNLP.io.pipe.cws import CWSPipe
class CWSPipeTest(unittest.TestCase):
class TestCWSPipe(unittest.TestCase):
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_process_from_file(self):
dataset_names = ['pku', 'cityu', 'as', 'msra']
@ -14,7 +14,7 @@ class CWSPipeTest(unittest.TestCase):
print(data_bundle)
class RunCWSPipeTest(unittest.TestCase):
class TestRunCWSPipe(unittest.TestCase):
def test_process_from_file(self):
dataset_names = ['msra']
for dataset_name in dataset_names:

View File

@ -7,7 +7,7 @@ from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MN
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestPipe(unittest.TestCase):
class TestMatchingPipe(unittest.TestCase):
def test_process_from_file(self):
for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]:
with self.subTest(pipe=pipe):
@ -17,7 +17,7 @@ class TestPipe(unittest.TestCase):
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestBertPipe(unittest.TestCase):
class TestMatchingBertPipe(unittest.TestCase):
def test_process_from_file(self):
for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]:
with self.subTest(pipe=pipe):
@ -26,7 +26,7 @@ class TestBertPipe(unittest.TestCase):
print(data_bundle)
class TestRunPipe(unittest.TestCase):
class TestRunMatchingPipe(unittest.TestCase):
def test_load(self):
for pipe in [RTEPipe, RTEBertPipe]: