mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00
train增加注释;attention增加注释;新增transformer分词
This commit is contained in:
parent
094a566155
commit
6a0a1ed4ad
@ -51,7 +51,9 @@ class Trainer(object):
|
||||
:param Optimizer optimizer: an optimizer object
|
||||
:param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.\\
|
||||
`ignore` will not check unused field; `warning` when warn if some field are not used; `strict` means
|
||||
it will raise error if some field are not used.
|
||||
it will raise error if some field are not used. 检查的原理是通过使用很小的batch(默认两个sample)来检查代码是否能够
|
||||
运行,但是这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况,;(2)
|
||||
模型中存在累加前向计算次数的,可能会多计算几次。建议将check_code_level设置为-1
|
||||
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
|
||||
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets
|
||||
smaller, add "-" in front of the string. For example::
|
||||
|
@ -46,6 +46,21 @@ class DotAtte(nn.Module):
|
||||
|
||||
class MultiHeadAtte(nn.Module):
|
||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte):
|
||||
"""
|
||||
实现的是以下内容
|
||||
QW1: (batch_size, seq_len, input_size) * (input_size, key_size)
|
||||
KW2: (batch_size, seq_len, input_size) * (input_size, key_size)
|
||||
VW3: (batch_size, seq_len, input_size) * (input_size, value_size)
|
||||
|
||||
softmax(QK^T/sqrt(scale))*V: (batch_size, seq_len, value_size) 多个head(num_atten指定)的结果为
|
||||
(batch_size, seq_len, value_size*num_atte)
|
||||
最终结果将上式过一个(value_size*num_atte, output_size)的线性层,output为(batch_size, seq_len, output_size)
|
||||
:param input_size: int, 输入的维度
|
||||
:param output_size: int, 输出特征的维度
|
||||
:param key_size: int, query和key映射到该维度
|
||||
:param value_size: int, value映射到该维度
|
||||
:param num_atte:
|
||||
"""
|
||||
super(MultiHeadAtte, self).__init__()
|
||||
self.in_linear = nn.ModuleList()
|
||||
for i in range(num_atte * 3):
|
||||
|
70
reproduction/chinese_word_segment/models/cws_transformer.py
Normal file
70
reproduction/chinese_word_segment/models/cws_transformer.py
Normal file
@ -0,0 +1,70 @@
|
||||
|
||||
|
||||
|
||||
"""
|
||||
使用transformer作为分词的encoder端
|
||||
|
||||
"""
|
||||
|
||||
from torch import nn
|
||||
import torch
|
||||
from fastNLP.modules.encoder.transformer import TransformerEncoder
|
||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField,seq_len_to_byte_mask
|
||||
from fastNLP.modules.decoder.CRF import allowed_transitions
|
||||
|
||||
class TransformerCWS(nn.Module):
|
||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
|
||||
hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4):
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(vocab_num, embed_dim)
|
||||
input_size = embed_dim
|
||||
if bigram_vocab_num:
|
||||
self.bigram_embedding = nn.Embedding(bigram_vocab_num, bigram_embed_dim)
|
||||
input_size += num_bigram_per_char*bigram_embed_dim
|
||||
|
||||
self.drop = nn.Dropout(embed_drop_p, inplace=True)
|
||||
|
||||
self.fc1 = nn.Linear(input_size, hidden_size)
|
||||
|
||||
value_size = hidden_size//num_heads
|
||||
self.transformer = TransformerEncoder(num_layers, input_size=input_size, output_size=hidden_size,
|
||||
key_size=value_size, value_size=value_size, num_atte=num_heads)
|
||||
|
||||
self.fc2 = nn.Linear(hidden_size, tag_size)
|
||||
|
||||
allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes')
|
||||
self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False,
|
||||
allowed_transitions=allowed_trans)
|
||||
|
||||
def forward(self, chars, target, seq_lens, bigrams=None):
|
||||
seq_lens = seq_lens
|
||||
masks = seq_len_to_byte_mask(seq_lens)
|
||||
x = self.embedding(chars)
|
||||
batch_size = x.size(0)
|
||||
length = x.size(1)
|
||||
if hasattr(self, 'bigram_embedding'):
|
||||
bigrams = self.bigram_embedding(bigrams) # batch_size x seq_lens x per_char x embed_size
|
||||
x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1)
|
||||
self.drop(x)
|
||||
x = self.fc1(x)
|
||||
feats = self.transformer(x, masks)
|
||||
feats = self.fc2(feats)
|
||||
losses = self.crf(feats, target, masks.float())
|
||||
|
||||
pred_dict = {}
|
||||
pred_dict['seq_lens'] = seq_lens
|
||||
pred_dict['loss'] = torch.mean(losses)
|
||||
|
||||
return pred_dict
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
transformer = TransformerCWS(10, embed_dim=100, bigram_vocab_num=10, bigram_embed_dim=100, num_bigram_per_char=8,
|
||||
hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4)
|
||||
chars = torch.randint(10, size=(4, 7)).long()
|
||||
bigrams = torch.randint(10, size=(4, 56)).long()
|
||||
seq_lens = torch.ones(4).long()*7
|
||||
target = torch.randint(4, size=(4, 7))
|
||||
|
||||
print(transformer(chars, target, seq_lens, bigrams))
|
Loading…
Reference in New Issue
Block a user