完善测试

This commit is contained in:
yh_cc 2019-05-09 00:15:25 +08:00
parent 25565fe0c9
commit 6d36dbe7fb
7 changed files with 78 additions and 35 deletions

View File

@ -8,6 +8,6 @@ from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequen
BertForTokenClassification
from .biaffine_parser import BiaffineParser, GraphParser
from .cnn_text_classification import CNNText
from .sequence_modeling import SeqLabeling, AdvSeqLabel
from .sequence_labeling import SeqLabeling, AdvSeqLabel
from .snli import ESIM
from .star_transformer import STSeqCls, STNLICls, STSeqLabel

View File

@ -43,7 +43,7 @@ class SeqLabeling(BaseModel):
x = self.Embedding(words)
# [batch_size, max_len, word_emb_dim]
x = self.Rnn(x)
x,_ = self.Rnn(x, seq_len)
# [batch_size, max_len, hidden_size * direction]
x = self.Linear(x)
# [batch_size, max_len, num_classes]
@ -55,13 +55,13 @@ class SeqLabeling(BaseModel):
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size,]
:return:
:return: {'pred': xx}, [batch_size, max_len]
"""
self.mask = self._make_mask(words, seq_len)
x = self.Embedding(words)
# [batch_size, max_len, word_emb_dim]
x = self.Rnn(x)
x, _ = self.Rnn(x, seq_len)
# [batch_size, max_len, hidden_size * direction]
x = self.Linear(x)
# [batch_size, max_len, num_classes]
@ -93,13 +93,13 @@ class SeqLabeling(BaseModel):
def _decode(self, x):
"""
:param torch.FloatTensor x: [batch_size, max_len, tag_size]
:return prediction: list of [decode path(list)]
:return prediction: [batch_size, max_len]
"""
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True)
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask)
return tag_seq
class AdvSeqLabel:
class AdvSeqLabel(nn.Module):
"""
更复杂的Sequence Labelling模型结构为Embedding, LayerNorm, 双向LSTM(两层)FCLayerNormDropOutFCCRF
"""
@ -115,17 +115,19 @@ class AdvSeqLabel:
:param dict id2words: tag id转为其tag word的表用于在CRF解码时防止解出非法的顺序比如'BMES'这个标签规范中'S'
不能出现在'B'之后这里也支持类似与'B-NN''-'前为标签类型的指示后面为具体的tag的情况这里不但会保证
'B-NN'后面不为'S-NN'还会保证'B-NN'后面不会出现'M-xx'(任何非'M-NN''E-NN'的情况)
:param str encoding_type: 支持"BIO", "BMES", "BEMSO"
:param str encoding_type: 支持"BIO", "BMES", "BEMSO", 只有在id2words不为None的情况游泳
"""
super().__init__()
self.Embedding = encoder.embedding.Embedding(init_embed)
self.norm1 = torch.nn.LayerNorm(self.Embedding.embedding_dim)
self.Rnn = torch.nn.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2, dropout=dropout,
self.Rnn = encoder.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2, dropout=dropout,
bidirectional=True, batch_first=True)
self.Linear1 = encoder.Linear(hidden_size * 2, hidden_size * 2 // 3)
self.Linear1 = nn.Linear(hidden_size * 2, hidden_size * 2 // 3)
self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3)
self.relu = torch.nn.LeakyReLU()
self.drop = torch.nn.Dropout(dropout)
self.Linear2 = encoder.Linear(hidden_size * 2 // 3, num_classes)
self.Linear2 = nn.Linear(hidden_size * 2 // 3, num_classes)
if id2words is None:
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
@ -137,9 +139,9 @@ class AdvSeqLabel:
def _decode(self, x):
"""
:param torch.FloatTensor x: [batch_size, max_len, tag_size]
:return prediction: list of [decode path(list)]
:return torch.LongTensor, [batch_size, max_len]
"""
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True)
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask)
return tag_seq
def _internal_loss(self, x, y):
@ -176,31 +178,20 @@ class AdvSeqLabel:
words = words.long()
seq_len = seq_len.long()
self.mask = self._make_mask(words, seq_len)
sent_len, idx_sort = torch.sort(seq_len, descending=True)
_, idx_unsort = torch.sort(idx_sort, descending=False)
# seq_len = seq_len.long()
target = target.long() if target is not None else None
if next(self.parameters()).is_cuda:
words = words.cuda()
idx_sort = idx_sort.cuda()
idx_unsort = idx_unsort.cuda()
self.mask = self.mask.cuda()
x = self.Embedding(words)
x = self.norm1(x)
# [batch_size, max_len, word_emb_dim]
sent_variable = x[idx_sort]
sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True)
x, _ = self.Rnn(x, seq_len=seq_len)
x, _ = self.Rnn(sent_packed)
sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0]
x = sent_output[idx_unsort]
x = x.contiguous()
x = self.Linear1(x)
x = self.norm2(x)
x = self.relu(x)
@ -225,6 +216,7 @@ class AdvSeqLabel:
:param torch.LongTensor words: [batch_size, mex_len]
:param torch.LongTensor seq_len:[batch_size, ]
:return: [list1, list2, ...], 内部每个list为一个路径已经unpad了
:return {'pred':}, value是torch.LongTensor, [batch_size, max_len]
"""
return self._forward(words, seq_len, )
return self._forward(words, seq_len)

View File

@ -13,7 +13,7 @@ from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor, SetInp
from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.trainer import Trainer
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.models.sequence_labeling import AdvSeqLabel
from fastNLP.io.dataset_loader import ConllxDataLoader
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor

View File

@ -0,0 +1,7 @@
5 50
the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581
of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375
to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044
and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097
in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285
a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796

View File

@ -3,7 +3,9 @@ import numpy as np
from fastNLP import Vocabulary
from fastNLP.io import EmbedLoader
import os
from fastNLP.io.dataset_loader import SSTLoader
from fastNLP.core.const import Const as C
class TestEmbedLoader(unittest.TestCase):
def test_load_with_vocab(self):
@ -36,4 +38,14 @@ class TestEmbedLoader(unittest.TestCase):
self.assertEqual(w_m.shape, (7, 50))
self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7)
for word in words:
self.assertIn(word, vocab)
self.assertIn(word, vocab)
def test_read_all_glove(self):
pass
# 这是可以运行的但是总数少于行数应该是由于glove有重复的word
# path = '/where/to/read/full/glove'
# init_embed, vocab = EmbedLoader.load_without_vocab(path, error='strict')
# print(init_embed.shape)
# print(init_embed.mean())
# print(np.isnan(init_embed).sum())
# print(len(vocab))

View File

@ -1,7 +1,7 @@
import unittest
from test.models.model_runner import *
from .model_runner import *
from fastNLP.models.cnn_text_classification import CNNText
@ -16,7 +16,3 @@ class TestCNNText(unittest.TestCase):
padding=0,
dropout=0.5)
RUNNER.run_model_with_task(TEXT_CLS, model)
if __name__ == '__main__':
TestCNNText().test_case1()

View File

@ -0,0 +1,36 @@
import unittest
from .model_runner import *
from fastNLP.models.sequence_labeling import SeqLabeling, AdvSeqLabel
from fastNLP.core.losses import LossInForward
class TesSeqLabel(unittest.TestCase):
def test_case1(self):
# 测试能否正常运行CNN
init_emb = (VOCAB_SIZE, 30)
model = SeqLabeling(init_emb,
hidden_size=30,
num_classes=NUM_CLS)
data = RUNNER.prepare_pos_tagging_data()
data.set_input('target')
loss = LossInForward()
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN)
RUNNER.run_model(model, data, loss, metric)
class TesAdvSeqLabel(unittest.TestCase):
def test_case1(self):
# 测试能否正常运行CNN
init_emb = (VOCAB_SIZE, 30)
model = AdvSeqLabel(init_emb,
hidden_size=30,
num_classes=NUM_CLS)
data = RUNNER.prepare_pos_tagging_data()
data.set_input('target')
loss = LossInForward()
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN)
RUNNER.run_model(model, data, loss, metric)