mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 13:17:51 +08:00
update some function about bert and roberta
This commit is contained in:
parent
84776696cd
commit
030e0aa3ee
@ -110,11 +110,12 @@ class BertEmbedding(ContextualEmbedding):
|
||||
if '[CLS]' in vocab:
|
||||
self._word_cls_index = vocab['[CLS]']
|
||||
|
||||
min_freq = kwargs.get('min_freq', 1)
|
||||
min_freq = kwargs.pop('min_freq', 1)
|
||||
self._min_freq = min_freq
|
||||
self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
|
||||
pool_method=pool_method, include_cls_sep=include_cls_sep,
|
||||
pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate)
|
||||
pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate,
|
||||
**kwargs)
|
||||
|
||||
self.requires_grad = requires_grad
|
||||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
|
||||
@ -367,32 +368,44 @@ class BertWordPieceEncoder(nn.Module):
|
||||
|
||||
class _BertWordModel(nn.Module):
|
||||
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
|
||||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
|
||||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(layers, list):
|
||||
self.layers = [int(l) for l in layers]
|
||||
elif isinstance(layers, str):
|
||||
self.layers = list(map(int, layers.split(',')))
|
||||
if layers.lower() == 'all':
|
||||
self.layers = None
|
||||
else:
|
||||
self.layers = list(map(int, layers.split(',')))
|
||||
else:
|
||||
raise TypeError("`layers` only supports str or list[int]")
|
||||
assert len(self.layers) > 0, "There is no layer selected!"
|
||||
|
||||
neg_num_output_layer = -16384
|
||||
pos_num_output_layer = 0
|
||||
for layer in self.layers:
|
||||
if layer < 0:
|
||||
neg_num_output_layer = max(layer, neg_num_output_layer)
|
||||
else:
|
||||
pos_num_output_layer = max(layer, pos_num_output_layer)
|
||||
if self.layers is None:
|
||||
neg_num_output_layer = -1
|
||||
else:
|
||||
for layer in self.layers:
|
||||
if layer < 0:
|
||||
neg_num_output_layer = max(layer, neg_num_output_layer)
|
||||
else:
|
||||
pos_num_output_layer = max(layer, pos_num_output_layer)
|
||||
|
||||
self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
|
||||
self.encoder = BertModel.from_pretrained(model_dir_or_name,
|
||||
neg_num_output_layer=neg_num_output_layer,
|
||||
pos_num_output_layer=pos_num_output_layer)
|
||||
pos_num_output_layer=pos_num_output_layer,
|
||||
**kwargs)
|
||||
self._max_position_embeddings = self.encoder.config.max_position_embeddings
|
||||
# 检查encoder_layer_number是否合理
|
||||
encoder_layer_number = len(self.encoder.encoder.layer)
|
||||
if self.layers is None:
|
||||
self.layers = [idx for idx in range(encoder_layer_number + 1)]
|
||||
logger.info(f'Bert Model will return {len(self.layers)} layers (layer-0 '
|
||||
f'is embedding result): {self.layers}')
|
||||
assert len(self.layers) > 0, "There is no layer selected!"
|
||||
for layer in self.layers:
|
||||
if layer < 0:
|
||||
assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
|
||||
@ -417,7 +430,7 @@ class _BertWordModel(nn.Module):
|
||||
word = '[PAD]'
|
||||
elif index == vocab.unknown_idx:
|
||||
word = '[UNK]'
|
||||
elif vocab.word_count[word]<min_freq:
|
||||
elif vocab.word_count[word] < min_freq:
|
||||
word = '[UNK]'
|
||||
word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
|
||||
word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
|
||||
@ -481,14 +494,15 @@ class _BertWordModel(nn.Module):
|
||||
token_type_ids = torch.zeros_like(word_pieces)
|
||||
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算
|
||||
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
|
||||
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
|
||||
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids,
|
||||
attention_mask=attn_masks,
|
||||
output_all_encoded_layers=True)
|
||||
# output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size
|
||||
|
||||
if self.include_cls_sep:
|
||||
s_shift = 1
|
||||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
|
||||
bert_outputs[-1].size(-1))
|
||||
bert_outputs[-1].size(-1))
|
||||
|
||||
else:
|
||||
s_shift = 0
|
||||
|
@ -93,12 +93,13 @@ class RobertaEmbedding(ContextualEmbedding):
|
||||
if '<s>' in vocab:
|
||||
self._word_cls_index = vocab['<s>']
|
||||
|
||||
min_freq = kwargs.get('min_freq', 1)
|
||||
min_freq = kwargs.pop('min_freq', 1)
|
||||
self._min_freq = min_freq
|
||||
|
||||
self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
|
||||
pool_method=pool_method, include_cls_sep=include_cls_sep,
|
||||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq)
|
||||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
|
||||
**kwargs)
|
||||
self.requires_grad = requires_grad
|
||||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
|
||||
|
||||
@ -193,33 +194,45 @@ class RobertaEmbedding(ContextualEmbedding):
|
||||
|
||||
class _RobertaWordModel(nn.Module):
|
||||
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
|
||||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
|
||||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(layers, list):
|
||||
self.layers = [int(l) for l in layers]
|
||||
elif isinstance(layers, str):
|
||||
self.layers = list(map(int, layers.split(',')))
|
||||
if layers.lower() == 'all':
|
||||
self.layers = None
|
||||
else:
|
||||
self.layers = list(map(int, layers.split(',')))
|
||||
else:
|
||||
raise TypeError("`layers` only supports str or list[int]")
|
||||
assert len(self.layers) > 0, "There is no layer selected!"
|
||||
|
||||
neg_num_output_layer = -16384
|
||||
pos_num_output_layer = 0
|
||||
for layer in self.layers:
|
||||
if layer < 0:
|
||||
neg_num_output_layer = max(layer, neg_num_output_layer)
|
||||
else:
|
||||
pos_num_output_layer = max(layer, pos_num_output_layer)
|
||||
if self.layers is None:
|
||||
neg_num_output_layer = -1
|
||||
else:
|
||||
for layer in self.layers:
|
||||
if layer < 0:
|
||||
neg_num_output_layer = max(layer, neg_num_output_layer)
|
||||
else:
|
||||
pos_num_output_layer = max(layer, pos_num_output_layer)
|
||||
|
||||
self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
|
||||
self.encoder = RobertaModel.from_pretrained(model_dir_or_name,
|
||||
neg_num_output_layer=neg_num_output_layer,
|
||||
pos_num_output_layer=pos_num_output_layer)
|
||||
pos_num_output_layer=pos_num_output_layer,
|
||||
**kwargs)
|
||||
# 由于RobertaEmbedding中设置了padding_idx为1, 且使用了非常神奇的position计算方式,所以-2
|
||||
self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
|
||||
# 检查encoder_layer_number是否合理
|
||||
encoder_layer_number = len(self.encoder.encoder.layer)
|
||||
if self.layers is None:
|
||||
self.layers = [idx for idx in range(encoder_layer_number + 1)]
|
||||
logger.info(f'RoBERTa Model will return {len(self.layers)} layers (layer-0 '
|
||||
f'is embedding result): {self.layers}')
|
||||
assert len(self.layers) > 0, "There is no layer selected!"
|
||||
for layer in self.layers:
|
||||
if layer < 0:
|
||||
assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
|
||||
@ -241,7 +254,7 @@ class _RobertaWordModel(nn.Module):
|
||||
word = '<pad>'
|
||||
elif index == vocab.unknown_idx:
|
||||
word = '<unk>'
|
||||
elif vocab.word_count[word]<min_freq:
|
||||
elif vocab.word_count[word] < min_freq:
|
||||
word = '<unk>'
|
||||
word_pieces = self.tokenizer.tokenize(word)
|
||||
word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces)
|
||||
@ -265,13 +278,15 @@ class _RobertaWordModel(nn.Module):
|
||||
batch_size, max_word_len = words.size()
|
||||
word_mask = words.ne(self._word_pad_index) # 为1的地方有word
|
||||
seq_len = word_mask.sum(dim=-1)
|
||||
batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False), 0) # batch_size x max_len
|
||||
batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
|
||||
0) # batch_size x max_len
|
||||
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size
|
||||
max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # 表示word piece的长度(包括padding)
|
||||
if max_word_piece_length + 2 > self._max_position_embeddings:
|
||||
if self.auto_truncate:
|
||||
word_pieces_lengths = word_pieces_lengths.masked_fill(
|
||||
word_pieces_lengths + 2 > self._max_position_embeddings, self._max_position_embeddings - 2)
|
||||
word_pieces_lengths + 2 > self._max_position_embeddings,
|
||||
self._max_position_embeddings - 2)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"After split words into word pieces, the lengths of word pieces are longer than the "
|
||||
@ -290,6 +305,7 @@ class _RobertaWordModel(nn.Module):
|
||||
word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
|
||||
word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
|
||||
attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
|
||||
# 添加<s>和</s>
|
||||
word_pieces[:, 0].fill_(self._cls_index)
|
||||
batch_indexes = torch.arange(batch_size).to(words)
|
||||
word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
|
||||
@ -362,6 +378,12 @@ class _RobertaWordModel(nn.Module):
|
||||
return outputs
|
||||
|
||||
def save(self, folder):
|
||||
"""
|
||||
给定一个folder保存pytorch_model.bin, config.json, vocab.txt
|
||||
|
||||
:param str folder:
|
||||
:return:
|
||||
"""
|
||||
self.tokenizer.save_pretrained(folder)
|
||||
self.encoder.save_pretrained(folder)
|
||||
|
||||
|
@ -184,21 +184,23 @@ class DistilBertEmbeddings(nn.Module):
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids, token_type_ids):
|
||||
def forward(self, input_ids, token_type_ids, position_ids=None):
|
||||
r"""
|
||||
Parameters
|
||||
----------
|
||||
input_ids: torch.tensor(bs, max_seq_length)
|
||||
The token ids to embed.
|
||||
token_type_ids: no used.
|
||||
position_ids: no used.
|
||||
Outputs
|
||||
-------
|
||||
embeddings: torch.tensor(bs, max_seq_length, dim)
|
||||
The embedded tokens (plus position embeddings, no token_type embeddings)
|
||||
"""
|
||||
seq_length = input_ids.size(1)
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
|
||||
|
||||
word_embeddings = self.word_embeddings(input_ids) # (bs, max_seq_length, dim)
|
||||
position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim)
|
||||
|
Loading…
Reference in New Issue
Block a user