Add a Chinese text-classification Pipe; speed up part of BertEmbedding's pool_method with matrix operations; rename some test cases; fix a spurious warning in the metrics
This commit is contained in:
parent d15ad75d96 · commit e903db0e70
@@ -238,8 +238,8 @@ class CrossEntropyLoss(LossBase):
             pred = pred.tranpose(-1, pred)
             pred = pred.reshape(-1, pred.size(-1))
             target = target.reshape(-1)
-        if seq_len is not None:
-            mask = seq_len_to_mask(seq_len).reshape(-1).eq(0)
+        if seq_len is not None and target.dim()>1:
+            mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0)
             target = target.masked_fill(mask, self.padding_idx)

         return F.cross_entropy(input=pred, target=target,
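The change above does two things: padding is only masked out when `target` is a per-token matrix (`target.dim()>1`), and the mask is built with `max_len=target.size(1)` so its flattened shape always matches the flattened target. A minimal sketch of that masking logic in plain PyTorch (the `seq_len_to_mask` below is a stand-in for fastNLP's helper, and -100 stands in for `self.padding_idx`):

import torch

def seq_len_to_mask(seq_len, max_len=None):
    # stand-in for fastNLP.core.utils.seq_len_to_mask
    max_len = int(max_len or seq_len.max())
    return torch.arange(max_len)[None, :] < seq_len[:, None]

seq_len = torch.tensor([3, 2])
target = torch.tensor([[1, 2, 3, 0],      # per-token targets, padded to length 4
                       [4, 5, 0, 0]])
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0)
target = target.reshape(-1).masked_fill(mask, -100)
print(target)  # tensor([   1,    2,    3, -100,    4,    5, -100, -100])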
@@ -347,7 +347,7 @@ class AccuracyMetric(MetricBase):
             pass
         elif pred.dim() == target.dim() + 1:
             pred = pred.argmax(dim=-1)
-            if seq_len is None:
+            if seq_len is None and target.dim()>1:
                 warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
         else:
             raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
@@ -68,7 +68,7 @@ class BertEmbedding(ContextualEmbedding):

    def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1',
                 pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False,
-                pooled_cls=True, requires_grad: bool = False, auto_truncate: bool = False):
+                pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False):
        super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

        if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
@@ -165,7 +165,7 @@ class BertWordPieceEncoder(nn.Module):
    """

    def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
-                word_dropout=0, dropout=0, requires_grad: bool = False):
+                word_dropout=0, dropout=0, requires_grad: bool = True):
        super().__init__()

        self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
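Both `BertEmbedding` and `BertWordPieceEncoder` now default to `requires_grad=True`, so the BERT weights are fine-tuned unless the caller opts out. A usage sketch under that assumption (requires the pretrained weights to be downloadable; model name only illustrative):

from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

vocab = Vocabulary().add_word_lst("this is a test .".split())
embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased')                        # fine-tuned by default now
frozen = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False)  # old frozen behaviour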
@@ -288,7 +288,7 @@ class _WordBertModel(nn.Module):
        self.auto_truncate = auto_truncate

        # compute the word pieces for every word in the vocab; [CLS] and [SEP] need extra handling
-       logger.info("Start to generating word pieces for word.")
+       logger.info("Start to generate word pieces for word.")
        # step 1: collect the word pieces that are needed, then create the new embed and word_piece_vocab and fill in the values
        word_piece_dict = {'[CLS]': 1, '[SEP]': 1}  # word pieces in use plus the newly added ones
        found_count = 0
@@ -374,7 +374,8 @@ class _WordBertModel(nn.Module):
                else:
                    raise RuntimeError(
                        "After split words into word pieces, the lengths of word pieces are longer than the "
-                       f"maximum allowed sequence length:{self._max_position_embeddings} of bert.")
+                       f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set "
+                       f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.")

        # +2 because [CLS] and [SEP] have to be added
        word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)),
@@ -407,15 +408,26 @@ class _WordBertModel(nn.Module):
        # output_layers = [self.layers]  # len(self.layers) x batch_size x real_word_piece_length x hidden_size

        if self.include_cls_sep:
-           outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
-                                                bert_outputs[-1].size(-1))
            s_shift = 1
+           outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
+                                                bert_outputs[-1].size(-1))
+
        else:
+           s_shift = 0
            outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len,
                                                 bert_outputs[-1].size(-1))
-           s_shift = 0
        batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1)
        batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1)  # batch_size x max_len
+
+       if self.pool_method == 'first':
+           batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()]
+           batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+           batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
+       elif self.pool_method == 'last':
+           batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1
+           batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0)
+           batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1)))
+
        for l_index, l in enumerate(self.layers):
            output_layer = bert_outputs[l]
            real_word_piece_length = output_layer.size(1) - 2
@@ -426,16 +438,15 @@ class _WordBertModel(nn.Module):
                output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
            # collapse from word-piece representations to word representations
            truncate_output_layer = output_layer[:, 1:-1]  # drop [CLS] and [SEP]  batch_size x len x hidden_size
-           outputs_seq_len = seq_len + s_shift
            if self.pool_method == 'first':
-               for i in range(batch_size):
-                   i_word_pieces_cum_length = batch_word_pieces_cum_length[i, :seq_len[i]]  # start position of each word
-                   outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[
-                       i, i_word_pieces_cum_length]  # num_layer x batch_size x len x hidden_size
+               tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length]
+               tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
+               outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
+
            elif self.pool_method == 'last':
-               for i in range(batch_size):
-                   i_word_pieces_cum_length = batch_word_pieces_cum_length[i, 1:seq_len[i] + 1] - 1  # end of each word
-                   outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length]
+               tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length]
+               tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
+               outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
            elif self.pool_method == 'max':
                for i in range(batch_size):
                    for j in range(seq_len[i]):
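The per-sample Python loops for the 'first' and 'last' pooling modes are replaced by a single advanced-indexing gather per layer, which is where the speed-up in the commit message comes from. A standalone sketch of the same trick in plain PyTorch (made-up shapes and data, not the fastNLP code; variable names mirror the diff):

import torch

batch_size, max_word_len, hidden = 2, 3, 4
word_piece_length = 6
truncate_output_layer = torch.randn(batch_size, word_piece_length, hidden)

batch_word_pieces_length = torch.tensor([[1, 2, 1], [2, 1, 0]])  # word pieces per word, zero-padded
seq_len = torch.tensor([3, 2])
word_mask = torch.arange(max_word_len)[None, :] < seq_len[:, None]

batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1)
batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1)

# 'first' pooling: index of the first word piece of every word
start = batch_word_pieces_cum_length[:, :seq_len.max()]
start = start.masked_fill(start.ge(word_piece_length), 0)
batch_indexes = torch.arange(batch_size)[:, None].expand(batch_size, start.size(1))

pooled = truncate_output_layer[batch_indexes, start]              # batch x words x hidden, one gather
pooled = pooled.masked_fill(word_mask[:, :start.size(1), None].eq(0), 0)

# equivalent per-sample loop (the form of code this commit removes)
for i in range(batch_size):
    assert torch.allclose(pooled[i, :seq_len[i]],
                          truncate_output_layer[i, start[i, :seq_len[i]]])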
@@ -452,5 +463,6 @@ class _WordBertModel(nn.Module):
                else:
                    outputs[l_index, :, 0] = output_layer[:, 0]
+               outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift]

        # 3. the final embedding result
        return outputs
@@ -24,6 +24,7 @@ __all__ = [
    'IMDBLoader',
    'SSTLoader',
    'SST2Loader',
+   "ChnSentiCorpLoader",

    'ConllLoader',
    'Conll2003Loader',
@@ -52,8 +53,9 @@ __all__ = [
    "SSTPipe",
    "SST2Pipe",
    "IMDBPipe",
-   "Conll2003Pipe",
+   "ChnSentiCorpPipe",
+
    "Conll2003Pipe",
    "Conll2003NERPipe",
    "OntoNotesNERPipe",
    "MsraNERPipe",
@@ -306,12 +306,15 @@ class DataBundle:
        return self

    def __repr__(self):
-       _str = 'In total {} datasets:\n'.format(len(self.datasets))
-       for name, dataset in self.datasets.items():
-           _str += '\t{} has {} instances.\n'.format(name, len(dataset))
-       _str += 'In total {} vocabs:\n'.format(len(self.vocabs))
-       for name, vocab in self.vocabs.items():
-           _str += '\t{} has {} entries.\n'.format(name, len(vocab))
+       _str = ''
+       if len(self.datasets):
+           _str += 'In total {} datasets:\n'.format(len(self.datasets))
+           for name, dataset in self.datasets.items():
+               _str += '\t{} has {} instances.\n'.format(name, len(dataset))
+       if len(self.vocabs):
+           _str += 'In total {} vocabs:\n'.format(len(self.vocabs))
+           for name, vocab in self.vocabs.items():
+               _str += '\t{} has {} entries.\n'.format(name, len(vocab))
        return _str

@@ -77,6 +77,9 @@ PRETRAIN_STATIC_FILES = {
    'cn-tencent': "tencent_cn.zip",
    'cn-fasttext': "cc.zh.300.vec.gz",
    'cn-sgns-literature-word': 'sgns.literature.word.txt.zip',
+   'cn-char-fastnlp-100d': "cn_char_fastnlp_100d.zip",
+   'cn-bi-fastnlp-100d': "cn_bi_fastnlp_100d.zip",
+   "cn-tri-fastnlp-100d": "cn_tri_fastnlp_100d.zip"
 }

 DATASET_DIR = {
@@ -96,7 +99,9 @@ DATASET_DIR = {
    "cws-pku": 'cws_pku.zip',
    "cws-cityu": "cws_cityu.zip",
    "cws-as": 'cws_as.zip',
-   "cws-msra": 'cws_msra.zip'
+   "cws-msra": 'cws_msra.zip',
+
+   "chn-senti-corp":"chn_senti_corp.zip"
 }

 PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
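With these entries the new fastNLP character/bigram/trigram vectors and the ChnSentiCorp dataset can be fetched by name. A small usage sketch, assuming the files download correctly (the vocabulary content is only illustrative):

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary().add_word_lst(list("复旦大学"))
embed = StaticEmbedding(vocab, model_dir_or_name='cn-char-fastnlp-100d')  # resolves via PRETRAIN_STATIC_FILES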
@@ -52,6 +52,7 @@ __all__ = [
    'IMDBLoader',
    'SSTLoader',
    'SST2Loader',
+   "ChnSentiCorpLoader",

    'ConllLoader',
    'Conll2003Loader',
@@ -73,7 +74,7 @@ __all__ = [
    "QNLILoader",
    "RTELoader"
 ]
-from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader
+from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader
 from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
 from .csv import CSVLoader
 from .cws import CWSLoader
@@ -7,6 +7,7 @@ __all__ = [
    "IMDBLoader",
    "SSTLoader",
    "SST2Loader",
+   "ChnSentiCorpLoader"
 ]

 import glob
@@ -346,3 +347,59 @@ class SST2Loader(Loader):
        """
        output_dir = self._get_dataset_path(dataset_name='sst-2')
        return output_dir
+
+
+class ChnSentiCorpLoader(Loader):
+    """
+    The supported file format: the first line is a header (its content is ignored); every following line is one
+    sample, where everything before the first tab is taken as the label and everything from the first tab on as
+    the sentence.
+
+    Example::
+
+        label	raw_chars
+        1	這間酒店環境和服務態度亦算不錯,但房間空間太小~~
+        1	<荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道...
+        0	商品的不足暂时还没发现,京东的订单处理速度实在.......周二就打包完成,周五才发货...
+
+    The loaded DataSet has the following fields
+
+    .. csv-table::
+        :header: "raw_chars", "target"
+
+        "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1"
+        "<荐书> 推荐所有喜欢<红楼>...", "1"
+        "..."
+
+    """
+    def __init__(self):
+        super().__init__()
+
+    def _load(self, path:str):
+        """
+        Read the data from path
+
+        :param path:
+        :return:
+        """
+        ds = DataSet()
+        with open(path, 'r', encoding='utf-8') as f:
+            f.readline()
+            for line in f:
+                line = line.strip()
+                tab_index = line.index('\t')
+                if tab_index!=-1:
+                    target = line[:tab_index]
+                    raw_chars = line[tab_index+1:]
+                    if raw_chars:
+                        ds.append(Instance(raw_chars=raw_chars, target=target))
+        return ds
+
+    def download(self)->str:
+        """
+        Automatically download the data. The data is taken from
+        https://github.com/pengming617/bert_classification/tree/master/data and is used in
+        https://arxiv.org/pdf/1904.09223.pdf and https://arxiv.org/pdf/1906.08101.pdf
+
+        :return:
+        """
+        output_dir = self._get_dataset_path('chn-senti-corp')
+        return output_dir
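A short usage sketch of the new loader (behaviour as described in the docstring above; the download step needs network access):

from fastNLP.io.loader.classification import ChnSentiCorpLoader

loader = ChnSentiCorpLoader()
data_dir = loader.download()          # fetches chn_senti_corp.zip on first use
data_bundle = loader.load(data_dir)   # DataSets with raw_chars and target fields
print(data_bundle)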
@@ -17,6 +17,7 @@ __all__ = [
    "SSTPipe",
    "SST2Pipe",
    "IMDBPipe",
+   "ChnSentiCorpPipe",

    "Conll2003NERPipe",
    "OntoNotesNERPipe",
@@ -39,7 +40,7 @@ __all__ = [
    "MNLIPipe",
 ]

-from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe
+from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe
 from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
 from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
     MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe
@@ -5,7 +5,8 @@ __all__ = [
    "YelpPolarityPipe",
    "SSTPipe",
    "SST2Pipe",
-   'IMDBPipe'
+   'IMDBPipe',
+   "ChnSentiCorpPipe"
 ]

 import re
@@ -13,18 +14,18 @@ import re
 from nltk import Tree

 from .pipe import Pipe
-from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
+from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance, _add_chars_field
 from ..data_bundle import DataBundle
 from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
 from ...core.const import Const
 from ...core.dataset import DataSet
 from ...core.instance import Instance
 from ...core.vocabulary import Vocabulary
+from ..loader.classification import ChnSentiCorpLoader

 nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')


 class _CLSPipe(Pipe):
    """
    Base class for classification problems; it tokenizes the classification data. By default it operates on the raw_words column and produces the words column
@@ -457,3 +458,97 @@ class IMDBPipe(_CLSPipe):
        data_bundle = self.process(data_bundle)

        return data_bundle
+
+
+class ChnSentiCorpPipe(Pipe):
+    """
+    After processing, the DataSet has the following structure
+
+    .. csv-table::
+        :header: "raw_chars", "chars", "target", "seq_len"
+
+        "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "[2, 3, 4, 5, ...]", 1, 31
+        "<荐书> 推荐所有喜欢<红楼>...", "[10, 21, ....]", 1, 25
+        "..."
+
+    where chars and seq_len are set as input and target is set as target
+
+    :param bool bigrams: whether to add a bigrams column. Bigrams are built as ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...].
+        If True, the returned DataSet has a column named bigrams, already converted to indices and set as input; the
+        corresponding vocab can be obtained via data_bundle.get_vocab('bigrams').
+    :param bool trigrams: whether to add a trigrams column. Trigrams are built as ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...].
+        If True, the returned DataSet has a column named trigrams, already converted to indices and set as input; the
+        corresponding vocab can be obtained via data_bundle.get_vocab('trigrams').
+    """
+    def __init__(self, bigrams=False, trigrams=False):
+        super().__init__()
+
+        self.bigrams = bigrams
+        self.trigrams = trigrams
+
+    def _tokenize(self, data_bundle):
+        """
+        Split "复旦大学" in the DataSet into ["复", "旦", "大", "学"]. Word segmentation could be supported later by extending this function.
+
+        :param data_bundle:
+        :return:
+        """
+        data_bundle.apply_field(list, field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT)
+        return data_bundle
+
+    def process(self, data_bundle:DataBundle):
+        """
+        The DataSet that can be processed should have the following fields
+
+        .. csv-table::
+            :header: "raw_chars", "target"
+
+            "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1"
+            "<荐书> 推荐所有喜欢<红楼>...", "1"
+            "..."
+
+        :param data_bundle:
+        :return:
+        """
+        _add_chars_field(data_bundle, lower=False)
+
+        data_bundle = self._tokenize(data_bundle)
+
+        input_field_names = [Const.CHAR_INPUT]
+        if self.bigrams:
+            for name, dataset in data_bundle.iter_datasets():
+                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
+                                    field_name=Const.CHAR_INPUT, new_field_name='bigrams')
+            input_field_names.append('bigrams')
+        if self.trigrams:
+            for name, dataset in data_bundle.iter_datasets():
+                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
+                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
+                                    field_name=Const.CHAR_INPUT, new_field_name='trigrams')
+            input_field_names.append('trigrams')
+
+        # index
+        _indexize(data_bundle, input_field_names, Const.TARGET)
+
+        input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
+        target_fields = [Const.TARGET]
+
+        for name, dataset in data_bundle.datasets.items():
+            dataset.add_seq_len(Const.CHAR_INPUT)
+
+        data_bundle.set_input(*input_fields)
+        data_bundle.set_target(*target_fields)
+
+        return data_bundle
+
+    def process_from_file(self, paths=None):
+        """
+
+        :param paths: see the load function of :class:`fastNLP.io.loader.Loader` for the supported path types.
+        :return: DataBundle
+        """
+        # load the data
+        data_bundle = ChnSentiCorpLoader().load(paths)
+        data_bundle = self.process(data_bundle)
+
+        return data_bundle
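End-to-end, the pipe mirrors the test added later in this commit. A sketch (field and vocab names assumed from the docstring above; downloading the dataset needs network access):

from fastNLP.io.pipe.classification import ChnSentiCorpPipe

data_bundle = ChnSentiCorpPipe(bigrams=True, trigrams=True).process_from_file()
print(data_bundle)                            # chars/bigrams/trigrams/seq_len as input, target as target
char_vocab = data_bundle.get_vocab('chars')
bigram_vocab = data_bundle.get_vocab('bigrams')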
@@ -222,14 +222,23 @@ class _CNNERPipe(Pipe):
    converted to indices. In the returned DataSet, chars, target and seq_len are set as input; target and seq_len are set as target.

    :param str encoding_type: which encoding to use for the target column; bioes and bio are supported.
+   :param bool bigrams: whether to add a bigrams column. Bigrams are built as ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...].
+       If True, the returned DataSet has a column named bigrams, already converted to indices and set as input; the
+       corresponding vocab can be obtained via data_bundle.get_vocab('bigrams').
+   :param bool trigrams: whether to add a trigrams column. Trigrams are built as ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...].
+       If True, the returned DataSet has a column named trigrams, already converted to indices and set as input; the
+       corresponding vocab can be obtained via data_bundle.get_vocab('trigrams').
    """

-   def __init__(self, encoding_type: str = 'bio'):
+   def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False):
        if encoding_type == 'bio':
            self.convert_tag = iob2
        else:
            self.convert_tag = lambda words: iob2bioes(iob2(words))

+       self.bigrams = bigrams
+       self.trigrams = trigrams
+
    def process(self, data_bundle: DataBundle) -> DataBundle:
        """
        The supported DataSet fields are
@@ -241,11 +250,11 @@ class _CNNERPipe(Pipe):
        "[青, 岛, 海, 牛, 队, 和, ...]", "[B-ORG, I-ORG, I-ORG, ...]"
        "[...]", "[...]"

-       The raw_chars column is List[str], the untransformed raw data; the chars column is List[int], the input converted to indices; the target column is List[int], the
-       target converted to indices. In the returned DataSet, chars, target and seq_len are set as input; target is set as target.
+       The raw_chars column is List[str], the untransformed raw data; the chars column is List[int], the input converted to indices; the target column is List[int],
+       the target converted to indices. In the returned DataSet, chars, target and seq_len are set as input; target is set as target.

-       :param ~fastNLP.DataBundle data_bundle: the DataSets in the passed-in DataBundle must contain the two fields raw_words and ner, and the content of both fields must be List[str].
-           The DataBundle is modified in place.
+       :param ~fastNLP.DataBundle data_bundle: the DataSets in the passed-in DataBundle must contain the two fields raw_words and ner, and the content of both
+           fields must be List[str]. The DataBundle is modified in place.
        :return: DataBundle
        """
        # convert the tags
@@ -253,11 +262,24 @@ class _CNNERPipe(Pipe):
            dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)

        _add_chars_field(data_bundle, lower=False)

+       input_field_names = [Const.CHAR_INPUT]
+       if self.bigrams:
+           for name, dataset in data_bundle.datasets.items():
+               dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
+                                   field_name=Const.CHAR_INPUT, new_field_name='bigrams')
+           input_field_names.append('bigrams')
+       if self.trigrams:
+           for name, dataset in data_bundle.datasets.items():
+               dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
+                                                  zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
+                                   field_name=Const.CHAR_INPUT, new_field_name='trigrams')
+           input_field_names.append('trigrams')
+
        # index
-       _indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET)
+       _indexize(data_bundle, input_field_names, Const.TARGET)

-       input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN]
+       input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
        target_fields = [Const.TARGET, Const.INPUT_LEN]

        for name, dataset in data_bundle.datasets.items():
@@ -13,6 +13,12 @@ class TestDownload(unittest.TestCase):
        words = torch.LongTensor([[2, 3, 4, 0]])
        print(embed(words).size())

+       for pool_method in ['first', 'last', 'max', 'avg']:
+           for include_cls_sep in [True, False]:
+               embed = BertEmbedding(vocab, model_dir_or_name='en', pool_method=pool_method,
+                                     include_cls_sep=include_cls_sep)
+               print(embed(words).size())
+
    def test_word_drop(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)
@@ -5,22 +5,22 @@ from fastNLP.io.loader.classification import YelpPolarityLoader
 from fastNLP.io.loader.classification import IMDBLoader
 from fastNLP.io.loader.classification import SST2Loader
 from fastNLP.io.loader.classification import SSTLoader
+from fastNLP.io.loader.classification import ChnSentiCorpLoader
 import os

 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
 class TestDownload(unittest.TestCase):
    def test_download(self):
-       for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]:
+       for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
            loader().download()

    def test_load(self):
-       for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]:
+       for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
            data_bundle = loader().load()
            print(data_bundle)


 class TestLoad(unittest.TestCase):

    def test_load(self):
        for loader in [IMDBLoader]:
            data_bundle = loader().load('test/data_for_tests/io/imdb')
@@ -5,7 +5,7 @@ from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNE
    Conll2003Loader


-class MSRANERTest(unittest.TestCase):
+class TestMSRANER(unittest.TestCase):
    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_download(self):
        MsraNERLoader().download(re_download=False)
@@ -13,13 +13,13 @@ class MSRANERTest(unittest.TestCase):
        print(data_bundle)


-class PeopleDailyTest(unittest.TestCase):
+class TestPeopleDaily(unittest.TestCase):
    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_download(self):
        PeopleDailyNERLoader().download()


-class WeiboNERTest(unittest.TestCase):
+class TestWeiboNER(unittest.TestCase):
    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_download(self):
        WeiboNERLoader().download()
@@ -3,7 +3,7 @@ import os
 from fastNLP.io.loader import CWSLoader


-class CWSLoaderTest(unittest.TestCase):
+class TestCWSLoader(unittest.TestCase):
    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_download(self):
        dataset_names = ['pku', 'cityu', 'as', 'msra']
@@ -13,7 +13,7 @@ class CWSLoaderTest(unittest.TestCase):
        print(data_bundle)


-class RunCWSLoaderTest(unittest.TestCase):
+class TestRunCWSLoader(unittest.TestCase):
    def test_cws_loader(self):
        dataset_names = ['msra']
        for dataset_name in dataset_names:
@@ -8,7 +8,7 @@ from fastNLP.io.loader.matching import MNLILoader
 import os

 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
-class TestDownload(unittest.TestCase):
+class TestMatchingDownload(unittest.TestCase):
    def test_download(self):
        for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]:
            loader().download()
@@ -21,8 +21,7 @@ class TestDownload(unittest.TestCase):
            print(data_bundle)


-class TestLoad(unittest.TestCase):
-
+class TestMatchingLoad(unittest.TestCase):
    def test_load(self):
        for loader in [RTELoader]:
            data_bundle = loader().load('test/data_for_tests/io/rte')
@@ -2,9 +2,10 @@ import unittest
 import os

 from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe
+from fastNLP.io.pipe.classification import ChnSentiCorpPipe

 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
-class TestPipe(unittest.TestCase):
+class TestClassificationPipe(unittest.TestCase):
    def test_process_from_file(self):
        for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
            with self.subTest(pipe=pipe):
@@ -14,8 +15,16 @@ class TestPipe(unittest.TestCase):


 class TestRunPipe(unittest.TestCase):

    def test_load(self):
        for pipe in [IMDBPipe]:
            data_bundle = pipe(tokenizer='raw').process_from_file('test/data_for_tests/io/imdb')
            print(data_bundle)
+
+
+@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+class TestCNClassificationPipe(unittest.TestCase):
+    def test_process_from_file(self):
+        for pipe in [ChnSentiCorpPipe]:
+            with self.subTest(pipe=pipe):
+                data_bundle = pipe(bigrams=True, trigrams=True).process_from_file()
+                print(data_bundle)
@@ -4,12 +4,14 @@ from fastNLP.io import MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, Conll2003Pipe


 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
-class TestPipe(unittest.TestCase):
+class TestConllPipe(unittest.TestCase):
    def test_process_from_file(self):
        for pipe in [MsraNERPipe, PeopleDailyPipe, WeiboNERPipe]:
            with self.subTest(pipe=pipe):
                print(pipe)
-               data_bundle = pipe().process_from_file()
+               data_bundle = pipe(bigrams=True, trigrams=True).process_from_file()
                print(data_bundle)
+               data_bundle = pipe(encoding_type='bioes').process_from_file()
+               print(data_bundle)

@@ -4,7 +4,7 @@ import os
 from fastNLP.io.pipe.cws import CWSPipe


-class CWSPipeTest(unittest.TestCase):
+class TestCWSPipe(unittest.TestCase):
    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_process_from_file(self):
        dataset_names = ['pku', 'cityu', 'as', 'msra']
@@ -14,7 +14,7 @@ class CWSPipeTest(unittest.TestCase):
        print(data_bundle)


-class RunCWSPipeTest(unittest.TestCase):
+class TestRunCWSPipe(unittest.TestCase):
    def test_process_from_file(self):
        dataset_names = ['msra']
        for dataset_name in dataset_names:
@@ -7,7 +7,7 @@ from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MN


 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
-class TestPipe(unittest.TestCase):
+class TestMatchingPipe(unittest.TestCase):
    def test_process_from_file(self):
        for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]:
            with self.subTest(pipe=pipe):
@@ -17,7 +17,7 @@ class TestPipe(unittest.TestCase):


 @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
-class TestBertPipe(unittest.TestCase):
+class TestMatchingBertPipe(unittest.TestCase):
    def test_process_from_file(self):
        for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]:
            with self.subTest(pipe=pipe):
@@ -26,7 +26,7 @@ class TestBertPipe(unittest.TestCase):
        print(data_bundle)


-class TestRunPipe(unittest.TestCase):
+class TestRunMatchingPipe(unittest.TestCase):

    def test_load(self):
        for pipe in [RTEPipe, RTEBertPipe]: