Improve tests
This commit is contained in:
parent 25565fe0c9
commit 6d36dbe7fb
@@ -8,6 +8,6 @@ from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequen
     BertForTokenClassification
 from .biaffine_parser import BiaffineParser, GraphParser
 from .cnn_text_classification import CNNText
-from .sequence_modeling import SeqLabeling, AdvSeqLabel
+from .sequence_labeling import SeqLabeling, AdvSeqLabel
 from .snli import ESIM
 from .star_transformer import STSeqCls, STNLICls, STSeqLabel
@@ -43,7 +43,7 @@ class SeqLabeling(BaseModel):
 
         x = self.Embedding(words)
         # [batch_size, max_len, word_emb_dim]
-        x = self.Rnn(x)
+        x,_ = self.Rnn(x, seq_len)
         # [batch_size, max_len, hidden_size * direction]
         x = self.Linear(x)
         # [batch_size, max_len, num_classes]
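This fix addresses two things at once: LSTM-style modules return an (output, hidden) tuple, so the old `x = self.Rnn(x)` bound the whole tuple rather than the features, and passing seq_len lets the recurrent layer skip padded positions. A minimal illustration of the tuple return with plain PyTorch (fastNLP's encoder wrapper additionally accepts seq_len; this sketch only shows the unpacking):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=50, hidden_size=30, batch_first=True)
x = torch.randn(4, 7, 50)   # [batch_size, max_len, word_emb_dim]
out, (h, c) = lstm(x)       # returns (output, (h_n, c_n)), not a tensor
print(out.shape)            # torch.Size([4, 7, 30])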
@@ -55,13 +55,13 @@ class SeqLabeling(BaseModel):
 
         :param torch.LongTensor words: [batch_size, max_len]
         :param torch.LongTensor seq_len: [batch_size,]
-        :return:
+        :return: {'pred': xx}, [batch_size, max_len]
         """
         self.mask = self._make_mask(words, seq_len)
 
         x = self.Embedding(words)
         # [batch_size, max_len, word_emb_dim]
-        x = self.Rnn(x)
+        x, _ = self.Rnn(x, seq_len)
         # [batch_size, max_len, hidden_size * direction]
         x = self.Linear(x)
         # [batch_size, max_len, num_classes]
@@ -93,13 +93,13 @@ class SeqLabeling(BaseModel):
     def _decode(self, x):
         """
         :param torch.FloatTensor x: [batch_size, max_len, tag_size]
-        :return prediction: list of [decode path(list)]
+        :return prediction: [batch_size, max_len]
         """
-        tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True)
+        tag_seq, _ = self.Crf.viterbi_decode(x, self.mask)
         return tag_seq
 
 
-class AdvSeqLabel:
+class AdvSeqLabel(nn.Module):
     """
     A more complex sequence labelling model. Structure: Embedding, LayerNorm, bidirectional LSTM (two layers), FC, LayerNorm, Dropout, FC, CRF.
     """
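With unpad=True gone, _decode now returns the padded [batch_size, max_len] tensor, matching the updated docstring; callers that still want ragged per-sentence paths can trim with seq_len themselves. A small sketch of that trimming step (hypothetical helper, not part of this diff):

def unpad_paths(tag_seq, seq_len):
    # tag_seq: LongTensor [batch_size, max_len]; seq_len: LongTensor [batch_size,]
    return [tags[:length].tolist() for tags, length in zip(tag_seq, seq_len)]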
@@ -115,17 +115,19 @@ class AdvSeqLabel:
         :param dict id2words: a table mapping tag ids to tag words, used during CRF decoding to rule out illegal transitions. For example, in the 'BMES' tagging scheme, 'S'
             must not appear after 'B'. Tags like 'B-NN' are also supported, where the part before '-' indicates the tag type and the part after it is the concrete tag. In that case it guarantees not only
             that 'B-NN' is not followed by 'S-NN', but also that 'B-NN' is never followed by 'M-xx' (anything other than 'M-NN' or 'E-NN').
-        :param str encoding_type: supports "BIO", "BMES", "BEMSO".
+        :param str encoding_type: supports "BIO", "BMES", "BEMSO"; only used when id2words is not None.
         """
         super().__init__()
 
         self.Embedding = encoder.embedding.Embedding(init_embed)
         self.norm1 = torch.nn.LayerNorm(self.Embedding.embedding_dim)
-        self.Rnn = torch.nn.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2, dropout=dropout,
+        self.Rnn = encoder.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2, dropout=dropout,
                                  bidirectional=True, batch_first=True)
-        self.Linear1 = encoder.Linear(hidden_size * 2, hidden_size * 2 // 3)
+        self.Linear1 = nn.Linear(hidden_size * 2, hidden_size * 2 // 3)
         self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3)
         self.relu = torch.nn.LeakyReLU()
         self.drop = torch.nn.Dropout(dropout)
-        self.Linear2 = encoder.Linear(hidden_size * 2 // 3, num_classes)
+        self.Linear2 = nn.Linear(hidden_size * 2 // 3, num_classes)
 
         if id2words is None:
             self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
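The id2words docstring above spells out exactly the transition rules a constrained CRF decoder enforces. As a self-contained illustration of those rules (plain Python, not fastNLP's actual constraint-building code), a BMES transition check could look like:

def bmes_transition_allowed(from_tag, to_tag):
    """True if to_tag may follow from_tag under the BMES rules above.
    Tags may carry a type suffix, e.g. 'B-NN'."""
    f, f_type = (from_tag.split('-') + [None])[:2]
    t, t_type = (to_tag.split('-') + [None])[:2]
    if f in ('B', 'M'):
        # inside a span: only M/E of the same type may continue it
        return t in ('M', 'E') and f_type == t_type
    if f in ('E', 'S'):
        # a span just ended: only a new span (B or S) may start
        return t in ('B', 'S')
    return False

assert not bmes_transition_allowed('B', 'S')        # 'S' cannot follow 'B'
assert not bmes_transition_allowed('B-NN', 'M-VV')  # type must match
assert bmes_transition_allowed('B-NN', 'E-NN')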
@@ -137,9 +139,9 @@ class AdvSeqLabel:
     def _decode(self, x):
         """
         :param torch.FloatTensor x: [batch_size, max_len, tag_size]
-        :return prediction: list of [decode path(list)]
+        :return torch.LongTensor, [batch_size, max_len]
         """
-        tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True)
+        tag_seq, _ = self.Crf.viterbi_decode(x, self.mask)
         return tag_seq
 
     def _internal_loss(self, x, y):
@@ -176,31 +178,20 @@ class AdvSeqLabel:
         words = words.long()
         seq_len = seq_len.long()
         self.mask = self._make_mask(words, seq_len)
-        sent_len, idx_sort = torch.sort(seq_len, descending=True)
-        _, idx_unsort = torch.sort(idx_sort, descending=False)
 
-        # seq_len = seq_len.long()
         target = target.long() if target is not None else None
 
-        if next(self.parameters()).is_cuda:
-            words = words.cuda()
-            idx_sort = idx_sort.cuda()
-            idx_unsort = idx_unsort.cuda()
-            self.mask = self.mask.cuda()
-
         x = self.Embedding(words)
         x = self.norm1(x)
         # [batch_size, max_len, word_emb_dim]
 
-        sent_variable = x[idx_sort]
-        sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True)
-
-        x, _ = self.Rnn(sent_packed)
-
-        sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0]
-        x = sent_output[idx_unsort]
-
+        x, _ = self.Rnn(x, seq_len=seq_len)
         x = x.contiguous()
         x = self.Linear1(x)
         x = self.norm2(x)
         x = self.relu(x)
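The deleted lines are the standard manual recipe for running an LSTM over a padded batch: sort by length, pack, run, unpack, unsort, plus explicit .cuda() calls. The refactor moves all of that behind encoder.LSTM, so forward collapses to one call. For reference, a sketch of the pattern the wrapper presumably encapsulates:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def run_lstm_padded(lstm, x, seq_len):
    # x: [batch_size, max_len, dim]; seq_len: [batch_size,]
    sorted_len, idx_sort = torch.sort(seq_len, descending=True)
    _, idx_unsort = torch.sort(idx_sort)
    packed = pack_padded_sequence(x[idx_sort], sorted_len.cpu(), batch_first=True)
    out, hidden = lstm(packed)
    out, _ = pad_packed_sequence(out, batch_first=True)  # back to [B, T, H]
    return out[idx_unsort], hidden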
@@ -225,6 +216,7 @@ class AdvSeqLabel:
 
         :param torch.LongTensor words: [batch_size, max_len]
         :param torch.LongTensor seq_len: [batch_size, ]
-        :return: [list1, list2, ...], each inner list is one decoded path, already unpadded.
+        :return: {'pred':}, the value is a torch.LongTensor, [batch_size, max_len]
+
         """
-        return self._forward(words, seq_len, )
+        return self._forward(words, seq_len)
@@ -13,7 +13,7 @@ from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor, SetInp
 from fastNLP.core.metrics import SpanFPreRecMetric
 from fastNLP.core.trainer import Trainer
 from fastNLP.io.config_io import ConfigLoader, ConfigSection
-from fastNLP.models.sequence_modeling import AdvSeqLabel
+from fastNLP.models.sequence_labeling import AdvSeqLabel
 from fastNLP.io.dataset_loader import ConllxDataLoader
 from fastNLP.api.processor import ModelProcessor, Index2WordProcessor
 
test/data_for_tests/word2vec_test.txt (new file, 7 lines)
@@ -0,0 +1,7 @@
+5 50
+the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581
+of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375
+to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044
+and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097
+in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285
+a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796
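The fixture uses the word2vec text convention: a header line giving a word count and the embedding dimension ('5 50'), then one word per line followed by its vector. A hedged sketch of reading this format (an illustrative parser, not fastNLP's EmbedLoader):

import numpy as np

def read_word2vec_txt(path):
    vectors = {}
    with open(path, encoding='utf-8') as f:
        _count, dim = map(int, f.readline().split())  # header, e.g. '5 50'
        for line in f:
            parts = line.rstrip().split(' ')
            vectors[parts[0]] = np.asarray(parts[1:], dtype=np.float32)
    return vectors, dim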
@@ -3,7 +3,9 @@ import numpy as np
 from fastNLP import Vocabulary
 from fastNLP.io import EmbedLoader
 
 import os
+from fastNLP.io.dataset_loader import SSTLoader
+from fastNLP.core.const import Const as C
 
 class TestEmbedLoader(unittest.TestCase):
     def test_load_with_vocab(self):
@@ -36,4 +38,14 @@ class TestEmbedLoader(unittest.TestCase):
         self.assertEqual(w_m.shape, (7, 50))
         self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7)
         for word in words:
-            self.assertIn(word, vocab)
+            self.assertIn(word, vocab)
+
+    def test_read_all_glove(self):
+        pass
+        # This does run, but the total count comes out smaller than the number of lines, presumably because glove contains duplicate words
+        # path = '/where/to/read/full/glove'
+        # init_embed, vocab = EmbedLoader.load_without_vocab(path, error='strict')
+        # print(init_embed.shape)
+        # print(init_embed.mean())
+        # print(np.isnan(init_embed).sum())
+        # print(len(vocab))
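For context, the two loader entry points this test file exercises differ in whether an existing Vocabulary selects the rows: load_with_vocab aligns the matrix to a given vocab (the shape and norm assertions above depend on that), while load_without_vocab builds a vocab from every word in the file. A usage sketch with call shapes inferred from the tests, so treat the exact signatures as assumptions:

from fastNLP import Vocabulary
from fastNLP.io import EmbedLoader

vocab = Vocabulary()
vocab.update(['the', 'of', 'in', 'a'])

# matrix aligned to vocab; shape (len(vocab), 50) for the fixture above
w_m = EmbedLoader.load_with_vocab('test/data_for_tests/word2vec_test.txt', vocab)

# no vocab given: one is built from the file itself
init_embed, built_vocab = EmbedLoader.load_without_vocab(
    'test/data_for_tests/word2vec_test.txt', error='strict')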
@@ -1,7 +1,7 @@
 
 import unittest
 
-from test.models.model_runner import *
+from .model_runner import *
 from fastNLP.models.cnn_text_classification import CNNText
 
 
@@ -16,7 +16,3 @@ class TestCNNText(unittest.TestCase):
                         padding=0,
                         dropout=0.5)
         RUNNER.run_model_with_task(TEXT_CLS, model)
-
-
-if __name__ == '__main__':
-    TestCNNText().test_case1()
test/models/test_sequence_labeling.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+
+
+import unittest
+
+from .model_runner import *
+from fastNLP.models.sequence_labeling import SeqLabeling, AdvSeqLabel
+from fastNLP.core.losses import LossInForward
+
+class TestSeqLabel(unittest.TestCase):
+    def test_case1(self):
+        # check that SeqLabeling runs end to end
+        init_emb = (VOCAB_SIZE, 30)
+        model = SeqLabeling(init_emb,
+                            hidden_size=30,
+                            num_classes=NUM_CLS)
+
+        data = RUNNER.prepare_pos_tagging_data()
+        data.set_input('target')
+        loss = LossInForward()
+        metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN)
+        RUNNER.run_model(model, data, loss, metric)
+
+
+class TestAdvSeqLabel(unittest.TestCase):
+    def test_case1(self):
+        # check that AdvSeqLabel runs end to end
+        init_emb = (VOCAB_SIZE, 30)
+        model = AdvSeqLabel(init_emb,
+                            hidden_size=30,
+                            num_classes=NUM_CLS)
+
+        data = RUNNER.prepare_pos_tagging_data()
+        data.set_input('target')
+        loss = LossInForward()
+        metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET, seq_len=C.INPUT_LEN)
+        RUNNER.run_model(model, data, loss, metric)
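Both tests pair the models with LossInForward, which takes the loss from the dict the model's forward returns rather than computing one from pred/target; that is why data.set_input('target') is needed, so the target reaches forward during training. Roughly, the contract looks like this (a sketch; _compute_logits is a hypothetical placeholder for the embedding/RNN/linear stack):

def forward(self, words, seq_len, target=None):
    logits = self._compute_logits(words, seq_len)  # hypothetical helper
    if target is not None:
        return {'loss': self._internal_loss(logits, target)}
    return {'pred': self._decode(logits)}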