Add comments to train; add comments to attention; add transformer-based word segmentation

yh 2019-01-15 14:58:43 +08:00
parent 094a566155
commit 6a0a1ed4ad
3 changed files with 88 additions and 1 deletion


@@ -51,7 +51,9 @@ class Trainer(object):
:param Optimizer optimizer: an optimizer object
:param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore, 1: warning, 2: strict.\\
`ignore` will not check unused fields; `warning` will warn if some fields are not used; `strict` means
it will raise an error if some fields are not used.
it will raise an error if some fields are not used. The check works by running the code on a very small
batch (two samples by default) to verify that it can run. In principle this process does not modify any
parameters; it only checks that the code runs. However, if (1) the model hard-codes batch_size to some
fixed value, or (2) the model accumulates a count of forward passes, the check may execute a few extra
forward computations; in such cases it is recommended to set check_code_level to -1.
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
of the keys returned by the FIRST metric in `metrics`. If the overall result is better when the indicator is
smaller, add "-" in front of the string. For example::

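The check described above for check_code_level amounts to a dry run on a tiny batch. The sketch below is purely illustrative of that idea, not FastNLP's actual checker; check_code, model and small_batch are hypothetical names local to this sketch:

import torch

def check_code(model, small_batch, check_code_level=0):
    # -1 skips the check entirely, the recommended setting when the model hard-codes
    # batch_size or accumulates a count of forward passes.
    if check_code_level == -1:
        return
    # One forward pass without gradients, so no parameters are modified; shape or field
    # errors surface here instead of partway through training.
    with torch.no_grad():
        model(**small_batch)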

@@ -46,6 +46,21 @@ class DotAtte(nn.Module):
class MultiHeadAtte(nn.Module):
def __init__(self, input_size, output_size, key_size, value_size, num_atte):
"""
实现的是以下内容
QW1: (batch_size, seq_len, input_size) * (input_size, key_size)
KW2: (batch_size, seq_len, input_size) * (input_size, key_size)
VW3: (batch_size, seq_len, input_size) * (input_size, value_size)
softmax(QK^T/sqrt(scale))*V: (batch_size, seq_len, value_size) 多个head(num_atten指定)的结果为
(batch_size, seq_len, value_size*num_atte)
最终结果将上式过一个value_size*num_atte, output_size)的线性层output为(batch_size, seq_len, output_size)
:param input_size: int, 输入的维度
:param output_size: int, 输出特征的维度
:param key_size: int, query和key映射到该维度
:param value_size: int, value映射到该维度
:param num_atte:
"""
super(MultiHeadAtte, self).__init__()
self.in_linear = nn.ModuleList()
for i in range(num_atte * 3):

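The shape flow documented in the MultiHeadAtte docstring above can be written out with plain torch ops; the sketch below only illustrates those shapes and is not the module's actual implementation (all sizes are arbitrary):

import math
import torch

batch_size, seq_len = 2, 5
input_size, key_size, value_size, output_size, num_atte = 16, 8, 8, 16, 4
x = torch.randn(batch_size, seq_len, input_size)

heads = []
for _ in range(num_atte):
    W1 = torch.randn(input_size, key_size)    # query projection
    W2 = torch.randn(input_size, key_size)    # key projection
    W3 = torch.randn(input_size, value_size)  # value projection
    Q, K, V = x @ W1, x @ W2, x @ W3          # (batch_size, seq_len, key_size or value_size)
    scores = torch.softmax(Q @ K.transpose(1, 2) / math.sqrt(key_size), dim=-1)
    heads.append(scores @ V)                  # (batch_size, seq_len, value_size)

concat = torch.cat(heads, dim=-1)             # (batch_size, seq_len, value_size*num_atte)
W_out = torch.randn(value_size * num_atte, output_size)
out = concat @ W_out                          # (batch_size, seq_len, output_size)
print(out.shape)                              # torch.Size([2, 5, 16])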

@@ -0,0 +1,70 @@
"""
使用transformer作为分词的encoder端
"""
from torch import nn
import torch
from fastNLP.modules.encoder.transformer import TransformerEncoder
from fastNLP.modules.decoder.CRF import ConditionalRandomField, seq_len_to_byte_mask
from fastNLP.modules.decoder.CRF import allowed_transitions
class TransformerCWS(nn.Module):
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4):
super().__init__()
self.embedding = nn.Embedding(vocab_num, embed_dim)
input_size = embed_dim
if bigram_vocab_num:
self.bigram_embedding = nn.Embedding(bigram_vocab_num, bigram_embed_dim)
input_size += num_bigram_per_char*bigram_embed_dim
self.drop = nn.Dropout(embed_drop_p, inplace=True)
self.fc1 = nn.Linear(input_size, hidden_size)
value_size = hidden_size//num_heads
# fc1 projects the concatenated embeddings to hidden_size, so the transformer is built on hidden_size features
self.transformer = TransformerEncoder(num_layers, input_size=hidden_size, output_size=hidden_size,
key_size=value_size, value_size=value_size, num_atte=num_heads)
self.fc2 = nn.Linear(hidden_size, tag_size)
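# BMES tag set: b(egin), m(iddle), e(nd), s(ingle-character word). allowed_transitions restricts
# the CRF to legal tag bigrams such as b->m, b->e, m->m, m->e, e->b, e->s, s->b and s->s.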
allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes')
self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False,
allowed_transitions=allowed_trans)
def forward(self, chars, target, seq_lens, bigrams=None):
masks = seq_len_to_byte_mask(seq_lens)  # (batch_size, max_len), 1 for real tokens, 0 for padding
x = self.embedding(chars)
batch_size = x.size(0)
length = x.size(1)
if hasattr(self, 'bigram_embedding'):
bigrams = self.bigram_embedding(bigrams) # batch_size x seq_lens x per_char x embed_size
x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1)
x = self.drop(x)  # embedding dropout (the Dropout module was constructed with inplace=True)
x = self.fc1(x)
feats = self.transformer(x, masks)
feats = self.fc2(feats)  # (batch_size, seq_len, tag_size) tag scores
losses = self.crf(feats, target, masks.float())  # per-sample CRF loss
pred_dict = {}
pred_dict['seq_lens'] = seq_lens
pred_dict['loss'] = torch.mean(losses)
return pred_dict
if __name__ == '__main__':
transformer = TransformerCWS(10, embed_dim=100, bigram_vocab_num=10, bigram_embed_dim=100, num_bigram_per_char=8,
hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4)
chars = torch.randint(10, size=(4, 7)).long()
bigrams = torch.randint(10, size=(4, 56)).long()
seq_lens = torch.ones(4).long()*7
target = torch.randint(4, size=(4, 7))
print(transformer(chars, target, seq_lens, bigrams))
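For reference, the smoke test above relies on the bigram features being flattened per character; a short sketch of the shape arithmetic (names are local to this snippet):

import torch
from torch import nn

batch_size, seq_len, num_bigram_per_char, bigram_embed_dim = 4, 7, 8, 100

# One row of seq_len * num_bigram_per_char bigram ids per sentence, as in the smoke test.
bigrams = torch.randint(10, size=(batch_size, seq_len * num_bigram_per_char)).long()

# Inside forward(): embedding lookup, then a view that groups the bigram embeddings per character.
embedded = nn.Embedding(10, bigram_embed_dim)(bigrams)              # (4, 56, 100)
per_char = embedded.view(batch_size, seq_len, -1)                   # (4, 7, 800)
assert per_char.size(-1) == num_bigram_per_char * bigram_embed_dim  # 800 = 8 * 100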