增加char_embedding可使用预训练的character embedding的功能

This commit is contained in:
YanqunJiang 2019-08-16 17:51:07 +08:00
parent f6bb8c83f3
commit fb82c66b4c

View File

@ -9,6 +9,7 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import List
from .static_embedding import StaticEmbedding
from ..modules.encoder.lstm import LSTM
from ..core.vocabulary import Vocabulary
from .embedding import TokenEmbedding
@ -41,10 +42,13 @@ class CNNCharEmbedding(TokenEmbedding):
:param pool_method: character的表示在合成一个表示时所使用的pool方法支持'avg', 'max'.
:param activation: CNN之后使用的激活方法支持'relu', 'sigmoid', 'tanh' 或者自定义函数.
:param min_char_freq: character的最少出现次数默认值为2.
:param pre_train_char_embed:可以有两种方式调用预训练好的static embedding第一种是传入embedding文件夹(文件夹下应该只有一个
.txt作为后缀的文件)或文件路径第二种是传入embedding的名称第二种情况将自动查看缓存中是否存在该模型没有的话将自动下载
如果输入为None则使用embedding_dim的维度随机初始化一个embedding.
"""
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1),
pool_method: str='max', activation='relu', min_char_freq: int=2):
pool_method: str='max', activation='relu', min_char_freq: int=2, pre_train_char_embed: str=''):
super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
for kernel in kernel_sizes:
@ -85,7 +89,11 @@ class CNNCharEmbedding(TokenEmbedding):
self.words_to_chars_embedding[index, :len(word)] = \
torch.LongTensor([self.char_vocab.to_index(c) for c in word])
self.word_lengths[index] = len(word)
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
# self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
if len(pre_train_char_embed):
self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
else:
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
self.convs = nn.ModuleList([nn.Conv1d(
char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
@ -184,10 +192,13 @@ class LSTMCharEmbedding(TokenEmbedding):
:param activation: 激活函数支持'relu', 'sigmoid', 'tanh', 或者自定义函数.
:param min_char_freq: character的最小出现次数默认值为2.
:param bidirectional: 是否使用双向的LSTM进行encode默认值为True
:param pre_train_char_embed:可以有两种方式调用预训练好的static embedding第一种是传入embedding文件夹(文件夹下应该只有一个
.txt作为后缀的文件)或文件路径第二种是传入embedding的名称第二种情况将自动查看缓存中是否存在该模型没有的话将自动下载
如果输入为None则使用embedding_dim的维度随机初始化一个embedding.
"""
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
dropout:float=0.5, hidden_size=50,pool_method: str='max', activation='relu', min_char_freq: int=2,
bidirectional=True):
bidirectional=True, pre_train_char_embed: str=''):
super(LSTMCharEmbedding, self).__init__(vocab)
assert hidden_size % 2 == 0, "Only even kernel is allowed."
@ -227,7 +238,11 @@ class LSTMCharEmbedding(TokenEmbedding):
self.words_to_chars_embedding[index, :len(word)] = \
torch.LongTensor([self.char_vocab.to_index(c) for c in word])
self.word_lengths[index] = len(word)
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
# self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
if len(pre_train_char_embed):
self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
else:
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
self.fc = nn.Linear(hidden_size, embed_size)
hidden_size = hidden_size // 2 if bidirectional else hidden_size