Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-02 20:27:35 +08:00)
Add support for using a pre-trained character embedding in char_embedding
This commit is contained in:
parent f6bb8c83f3
commit fb82c66b4c
@@ -9,6 +9,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import List
 
+from .static_embedding import StaticEmbedding
 from ..modules.encoder.lstm import LSTM
 from ..core.vocabulary import Vocabulary
 from .embedding import TokenEmbedding
@@ -41,10 +42,13 @@ class CNNCharEmbedding(TokenEmbedding):
     :param pool_method: pooling method used to combine the character representations into one representation; supports 'avg', 'max'.
     :param activation: activation applied after the CNN; supports 'relu', 'sigmoid', 'tanh', or a custom function.
     :param min_char_freq: minimum number of occurrences for a character to be kept. Default: 2.
+    :param pre_train_char_embed: a pre-trained static embedding can be loaded in two ways: either pass an embedding folder (the folder
+        should contain exactly one file with a .txt suffix) or a file path; or pass the name of an embedding, in which case the cache is
+        checked for that model and it is downloaded automatically if not found. If None, an embedding of dimension embedding_dim is randomly initialized.
     """
     def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
                  dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1),
-                 pool_method: str='max', activation='relu', min_char_freq: int=2):
+                 pool_method: str='max', activation='relu', min_char_freq: int=2, pre_train_char_embed: str=''):
         super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         for kernel in kernel_sizes:
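The new pre_train_char_embed argument points StaticEmbedding at a character-vector file or a named embedding. The exact file format is not shown in this diff; the sketch below assumes the common static-embedding text layout (one character per line followed by its vector), and the file name, character set, and dimension are illustrative only.

# Minimal sketch (assumed: file name, characters, and the plain
# "token v1 v2 ... vd" text format usually read by static embedding loaders).
import random

dim = 50  # kept equal to char_emb_size in the examples below
with open("char_vectors.txt", "w", encoding="utf-8") as f:
    for ch in "abcdefghijklmnopqrstuvwxyz":
        vec = " ".join("%.4f" % random.uniform(-0.5, 0.5) for _ in range(dim))
        f.write("%s %s\n" % (ch, vec))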
@@ -85,7 +89,11 @@ class CNNCharEmbedding(TokenEmbedding):
             self.words_to_chars_embedding[index, :len(word)] = \
                 torch.LongTensor([self.char_vocab.to_index(c) for c in word])
             self.word_lengths[index] = len(word)
-        self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+        # self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+        if len(pre_train_char_embed):
+            self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
+        else:
+            self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
 
         self.convs = nn.ModuleList([nn.Conv1d(
             char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
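With this fallback, CNNCharEmbedding keeps its previous behaviour when pre_train_char_embed is empty and switches to StaticEmbedding otherwise. A minimal usage sketch follows, assuming the fastNLP import paths of this repository and the char_vectors.txt file written above; note the conv layers still take char_emb_size input channels, so the pre-trained vectors are assumed to have that dimension.

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import CNNCharEmbedding  # import path assumed from the repo layout

vocab = Vocabulary()
vocab.add_word_lst("this is a demo sentence".split())

# Load the (hypothetical) pre-trained character vectors instead of a random nn.Embedding.
embed = CNNCharEmbedding(vocab, embed_size=50, char_emb_size=50,
                         pre_train_char_embed="char_vectors.txt")

words = torch.LongTensor([[vocab.to_index(w) for w in "this is a demo".split()]])
print(embed(words).size())  # expected: torch.Size([1, 4, 50])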
@@ -184,10 +192,13 @@ class LSTMCharEmbedding(TokenEmbedding):
     :param activation: activation function; supports 'relu', 'sigmoid', 'tanh', or a custom function.
     :param min_char_freq: minimum number of occurrences for a character to be kept. Default: 2.
     :param bidirectional: whether to use a bidirectional LSTM for encoding. Default: True.
+    :param pre_train_char_embed: a pre-trained static embedding can be loaded in two ways: either pass an embedding folder (the folder
+        should contain exactly one file with a .txt suffix) or a file path; or pass the name of an embedding, in which case the cache is
+        checked for that model and it is downloaded automatically if not found. If None, an embedding of dimension embedding_dim is randomly initialized.
     """
     def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
                  dropout:float=0.5, hidden_size=50, pool_method: str='max', activation='relu', min_char_freq: int=2,
-                 bidirectional=True):
+                 bidirectional=True, pre_train_char_embed: str=''):
         super(LSTMCharEmbedding, self).__init__(vocab)
 
         assert hidden_size % 2 == 0, "Only even kernel is allowed."
@@ -227,7 +238,11 @@ class LSTMCharEmbedding(TokenEmbedding):
             self.words_to_chars_embedding[index, :len(word)] = \
                 torch.LongTensor([self.char_vocab.to_index(c) for c in word])
             self.word_lengths[index] = len(word)
-        self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+        # self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+        if len(pre_train_char_embed):
+            self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
+        else:
+            self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
 
         self.fc = nn.Linear(hidden_size, embed_size)
         hidden_size = hidden_size // 2 if bidirectional else hidden_size
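The same parameter is wired into LSTMCharEmbedding. A short sketch under the same assumptions; passing a registered embedding name instead of a path (the second way described in the docstring) would check the cache and download the model if it is missing.

from fastNLP import Vocabulary
from fastNLP.embeddings import LSTMCharEmbedding  # import path assumed from the repo layout

vocab = Vocabulary()
vocab.add_word_lst("another short demo".split())

lstm_embed = LSTMCharEmbedding(vocab, embed_size=50, char_emb_size=50, hidden_size=50,
                               bidirectional=True, pre_train_char_embed="char_vectors.txt")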