import unittest

from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
import torch
from torch import optim
import torch.nn.functional as F
from fastNLP import seq_len_to_mask


def prepare_env():
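    """Build a tiny vocabulary, a random 5-dim embedding, and toy (src, tgt) batches shared by the tests."""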
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

    src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
    tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
    src_seq_len = torch.LongTensor([3, 2])
    tgt_seq_len = torch.LongTensor([4, 2])

    return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len


def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len):
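    """Overfit the model on the toy batch for 100 steps and return the number of correctly predicted target tokens."""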
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    mask = seq_len_to_mask(tgt_seq_len).eq(0)
    # padded target positions get label -100, the default ignore_index of F.cross_entropy
    target = tgt_words_idx.masked_fill(mask, -100)

    for i in range(100):
        optimizer.zero_grad()
        pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred']  # bsz x max_len x vocab_size
        loss = F.cross_entropy(pred.transpose(1, 2), target)
        loss.backward()
        optimizer.step()

    # padded positions are counted as correct so the result can be compared with tgt_words_idx.nelement()
    right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum()
    return right_count


class TestTransformerSeq2SeqModel(unittest.TestCase):
    def test_run(self):
        # check that the forward pass runs
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()
        for pos_embed in ['learned', 'sin']:
            with self.subTest(pos_embed=pos_embed):
                model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
                                                            pos_embed=pos_embed, max_position=20, num_layers=2,
                                                            d_model=30, n_head=6, dim_ff=20, dropout=0.1,
                                                            bind_encoder_decoder_embed=True,
                                                            bind_decoder_input_output_embed=True)

                output = model(src_words_idx, tgt_words_idx, src_seq_len)
                self.assertEqual(output['pred'].size(), (2, 4, len(embed)))

        for bind_encoder_decoder_embed in [True, False]:
            tgt_embed = None
            for bind_decoder_input_output_embed in [True, False]:
                if not bind_encoder_decoder_embed:
                    tgt_embed = embed
                with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                                  bind_decoder_input_output_embed=bind_decoder_input_output_embed):
                    model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed,
                                                                pos_embed='sin', max_position=20, num_layers=2,
                                                                d_model=30, n_head=6, dim_ff=20, dropout=0.1,
                                                                bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                                                                bind_decoder_input_output_embed=bind_decoder_input_output_embed)

                    output = model(src_words_idx, tgt_words_idx, src_seq_len)
                    self.assertEqual(output['pred'].size(), (2, 4, len(embed)))

    def test_train(self):
        # check that the model can be trained to overfit the toy data
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

        model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
                                                    pos_embed='sin', max_position=20, num_layers=2,
                                                    d_model=30, n_head=6, dim_ff=20, dropout=0.1,
                                                    bind_encoder_decoder_embed=True,
                                                    bind_decoder_input_output_embed=True)

        right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
        self.assertEqual(right_count, tgt_words_idx.nelement())


class TestLSTMSeq2SeqModel(unittest.TestCase):
    def test_run(self):
        # check that the forward pass runs
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

        for bind_encoder_decoder_embed in [True, False]:
            tgt_embed = None
            for bind_decoder_input_output_embed in [True, False]:
                if not bind_encoder_decoder_embed:
                    tgt_embed = embed
                with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                                  bind_decoder_input_output_embed=bind_decoder_input_output_embed):
                    model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed,
                                                         num_layers=2, hidden_size=20, dropout=0.1,
                                                         bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                                                         bind_decoder_input_output_embed=bind_decoder_input_output_embed)
                    output = model(src_words_idx, tgt_words_idx, src_seq_len)
                    self.assertEqual(output['pred'].size(), (2, 4, len(embed)))

    def test_train(self):
        # check that the model can be trained to overfit the toy data
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

        model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
                                             num_layers=1, hidden_size=20, dropout=0.1,
                                             bind_encoder_decoder_embed=True,
                                             bind_decoder_input_output_embed=True)

        right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
        self.assertEqual(right_count, tgt_words_idx.nelement())
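

if __name__ == "__main__":
    # convenience entry point so the test module can be run directly;
    # tests are normally collected by the project's test runner
    unittest.main()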