diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 109315a3..add86156 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -51,7 +51,9 @@ class Trainer(object): :param Optimizer optimizer: an optimizer object :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.\\ `ignore` will not check unused field; `warning` when warn if some field are not used; `strict` means - it will raise error if some field are not used. + it will raise error if some field are not used. 检查的原理是通过使用很小的batch(默认两个sample)来检查代码是否能够 + 运行,但是这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况,;(2) + 模型中存在累加前向计算次数的,可能会多计算几次。建议将check_code_level设置为-1 :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets smaller, add "-" in front of the string. For example:: diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index 3fea1b10..9f7d72dc 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -46,6 +46,21 @@ class DotAtte(nn.Module): class MultiHeadAtte(nn.Module): def __init__(self, input_size, output_size, key_size, value_size, num_atte): + """ + 实现的是以下内容 + QW1: (batch_size, seq_len, input_size) * (input_size, key_size) + KW2: (batch_size, seq_len, input_size) * (input_size, key_size) + VW3: (batch_size, seq_len, input_size) * (input_size, value_size) + + softmax(QK^T/sqrt(scale))*V: (batch_size, seq_len, value_size) 多个head(num_atten指定)的结果为 + (batch_size, seq_len, value_size*num_atte) + 最终结果将上式过一个(value_size*num_atte, output_size)的线性层,output为(batch_size, seq_len, output_size) + :param input_size: int, 输入的维度 + :param output_size: int, 输出特征的维度 + :param key_size: int, query和key映射到该维度 + :param value_size: int, value映射到该维度 + :param num_atte: + """ super(MultiHeadAtte, self).__init__() self.in_linear = nn.ModuleList() for i in range(num_atte * 3): diff --git a/reproduction/chinese_word_segment/models/cws_transformer.py b/reproduction/chinese_word_segment/models/cws_transformer.py new file mode 100644 index 00000000..3fcf91b5 --- /dev/null +++ b/reproduction/chinese_word_segment/models/cws_transformer.py @@ -0,0 +1,70 @@ + + + +""" +使用transformer作为分词的encoder端 + +""" + +from torch import nn +import torch +from fastNLP.modules.encoder.transformer import TransformerEncoder +from fastNLP.modules.decoder.CRF import ConditionalRandomField,seq_len_to_byte_mask +from fastNLP.modules.decoder.CRF import allowed_transitions + +class TransformerCWS(nn.Module): + def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, + hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4): + super().__init__() + + self.embedding = nn.Embedding(vocab_num, embed_dim) + input_size = embed_dim + if bigram_vocab_num: + self.bigram_embedding = nn.Embedding(bigram_vocab_num, bigram_embed_dim) + input_size += num_bigram_per_char*bigram_embed_dim + + self.drop = nn.Dropout(embed_drop_p, inplace=True) + + self.fc1 = nn.Linear(input_size, hidden_size) + + value_size = hidden_size//num_heads + self.transformer = TransformerEncoder(num_layers, input_size=input_size, output_size=hidden_size, + key_size=value_size, value_size=value_size, num_atte=num_heads) + + self.fc2 = nn.Linear(hidden_size, tag_size) + + allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes') + self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False, + allowed_transitions=allowed_trans) + + def forward(self, chars, target, seq_lens, bigrams=None): + seq_lens = seq_lens + masks = seq_len_to_byte_mask(seq_lens) + x = self.embedding(chars) + batch_size = x.size(0) + length = x.size(1) + if hasattr(self, 'bigram_embedding'): + bigrams = self.bigram_embedding(bigrams) # batch_size x seq_lens x per_char x embed_size + x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1) + self.drop(x) + x = self.fc1(x) + feats = self.transformer(x, masks) + feats = self.fc2(feats) + losses = self.crf(feats, target, masks.float()) + + pred_dict = {} + pred_dict['seq_lens'] = seq_lens + pred_dict['loss'] = torch.mean(losses) + + return pred_dict + + +if __name__ == '__main__': + transformer = TransformerCWS(10, embed_dim=100, bigram_vocab_num=10, bigram_embed_dim=100, num_bigram_per_char=8, + hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4) + chars = torch.randint(10, size=(4, 7)).long() + bigrams = torch.randint(10, size=(4, 56)).long() + seq_lens = torch.ones(4).long()*7 + target = torch.randint(4, size=(4, 7)) + + print(transformer(chars, target, seq_lens, bigrams)) \ No newline at end of file