Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-03 20:57:37 +08:00)
fix bugs in model/bert.py and add testing codes
parent e206cae45c
commit 016f02be3b
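At a glance: fastNLP/models/bert.py gains a plain BertConfig class, the four task heads make config and bert_dir optional, and the three classification heads move num_labels to the first constructor argument, so a model can now be built without a pretrained checkpoint. A minimal sketch of the new calling convention (illustrative only, not part of the commit; import path as used in the tests below):

```python
from fastNLP.models.bert import BertConfig, BertForSequenceClassification

# num_labels now comes first; config and bert_dir are optional keyword arguments.
clf_default = BertForSequenceClassification(2)                      # falls back to BertConfig()
clf_custom = BertForSequenceClassification(2, config=BertConfig())  # explicit, randomly initialized
```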
fastNLP/models/bert.py

@@ -10,6 +10,35 @@ from ..core.const import Const
 from ..modules.encoder import BertModel
 
 
+class BertConfig:
+
+    def __init__(
+        self,
+        vocab_size=30522,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=2,
+        initializer_range=0.02
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.initializer_range = initializer_range
+
+
 class BertForSequenceClassification(BaseModel):
     """BERT model for classification.
     This module is composed of the BERT model with a linear layer on top of
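The point of the new class: every BertConfig attribute mirrors a keyword argument of BertModel, and the constructors below unpack a config instance via BertModel(**config.__dict__) whenever no pretrained directory is supplied. A small sketch of that pattern (module paths as in this diff; a hypothetical snippet, not code from the commit):

```python
from fastNLP.models.bert import BertConfig
from fastNLP.modules.encoder import BertModel

# A deliberately small config; each attribute maps onto a BertModel keyword argument.
config = BertConfig(vocab_size=30522, num_hidden_layers=3)
encoder = BertModel(**config.__dict__)  # randomly initialized, no checkpoint needed
```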
@@ -44,14 +73,19 @@ class BertForSequenceClassification(BaseModel):
     config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
     num_labels = 2
-    model = BertForSequenceClassification(config, num_labels)
+    model = BertForSequenceClassification(num_labels, config)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels, bert_dir):
+    def __init__(self, num_labels, config=None, bert_dir=None):
         super(BertForSequenceClassification, self).__init__()
         self.num_labels = num_labels
-        self.bert = BertModel.from_pretrained(bert_dir)
+        if bert_dir is not None:
+            self.bert = BertModel.from_pretrained(bert_dir)
+        else:
+            if config is None:
+                config = BertConfig()
+            self.bert = BertModel(**config.__dict__)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
 
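One caveat worth noticing in this hunk (it repeats in the three hunks below): config.hidden_dropout_prob and config.hidden_size are still read after the if/else, so passing only bert_dir leaves config as None and building the head would raise an AttributeError. Also, the docstring example above still constructs BertConfig with vocab_size_or_config_json_file, apparently a leftover from the upstream pytorch-pretrained-bert config; the BertConfig added in this commit takes vocab_size instead. A hedged illustration of the constructor paths, keeping 'your-bert-file-dir' as the placeholder the docstrings already use:

```python
from fastNLP.models.bert import BertConfig, BertForSequenceClassification

# Fine: random weights from an explicit (or default) config.
model = BertForSequenceClassification(2, config=BertConfig())

# Loading pretrained weights still needs a config object for the head layers:
# model = BertForSequenceClassification(2, config=BertConfig(), bert_dir='your-bert-file-dir')

# As committed, this would raise AttributeError: config stays None, but
# nn.Dropout(config.hidden_dropout_prob) is evaluated right after the branch.
# model = BertForSequenceClassification(2, bert_dir='your-bert-file-dir')
```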
@@ -106,14 +140,19 @@ class BertForMultipleChoice(BaseModel):
     config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
     num_choices = 2
-    model = BertForMultipleChoice(config, num_choices, bert_dir)
+    model = BertForMultipleChoice(num_choices, config, bert_dir)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices, bert_dir):
+    def __init__(self, num_choices, config=None, bert_dir=None):
         super(BertForMultipleChoice, self).__init__()
         self.num_choices = num_choices
-        self.bert = BertModel.from_pretrained(bert_dir)
+        if bert_dir is not None:
+            self.bert = BertModel.from_pretrained(bert_dir)
+        else:
+            if config is None:
+                config = BertConfig()
+            self.bert = BertModel(**config.__dict__)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
 
@@ -174,14 +213,19 @@ class BertForTokenClassification(BaseModel):
         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
     num_labels = 2
     bert_dir = 'your-bert-file-dir'
-    model = BertForTokenClassification(config, num_labels, bert_dir)
+    model = BertForTokenClassification(num_labels, config, bert_dir)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels, bert_dir):
+    def __init__(self, num_labels, config=None, bert_dir=None):
         super(BertForTokenClassification, self).__init__()
         self.num_labels = num_labels
-        self.bert = BertModel.from_pretrained(bert_dir)
+        if bert_dir is not None:
+            self.bert = BertModel.from_pretrained(bert_dir)
+        else:
+            if config is None:
+                config = BertConfig()
+            self.bert = BertModel(**config.__dict__)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
 
@@ -252,9 +296,14 @@ class BertForQuestionAnswering(BaseModel):
     start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, bert_dir):
+    def __init__(self, config=None, bert_dir=None):
         super(BertForQuestionAnswering, self).__init__()
-        self.bert = BertModel.from_pretrained(bert_dir)
+        if bert_dir is not None:
+            self.bert = BertModel.from_pretrained(bert_dir)
+        else:
+            if config is None:
+                config = BertConfig()
+            self.bert = BertModel(**config.__dict__)
         # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
         # self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
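Unlike the three classification heads, BertForQuestionAnswering takes no num_labels: qa_outputs always projects each token to two scores, interpreted as span-start and span-end logits. Mirroring test_bert_4 below, a default-config forward pass would look like this (hypothetical usage sketch, not code from the commit):

```python
import torch
from fastNLP.core.const import Const
from fastNLP.models.bert import BertForQuestionAnswering

model = BertForQuestionAnswering()  # default BertConfig(), randomly initialized

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

pred = model(input_ids, token_type_ids, input_mask)
start_logits = pred[Const.OUTPUTS(0)]  # shape (batch=2, seq_len=3): start score per token
end_logits = pred[Const.OUTPUTS(1)]    # shape (batch=2, seq_len=3): end score per token
```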
test/models/test_bert.py

@@ -2,20 +2,64 @@ import unittest
 
 import torch
 
-from fastNLP.models.bert import BertModel
+from fastNLP.models.bert import *
 
 
 class TestBert(unittest.TestCase):
     def test_bert_1(self):
-        # model = BertModel.from_pretrained("/home/zyfeng/data/bert-base-chinese")
-        model = BertModel(vocab_size=32000, hidden_size=768,
-                          num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+        from fastNLP.core.const import Const
+
+        model = BertForSequenceClassification(2)
 
         input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
         input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
         token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
 
-        all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
-        for layer in all_encoder_layers:
-            self.assertEqual(tuple(layer.shape), (2, 3, 768))
-        self.assertEqual(tuple(pooled_output.shape), (2, 768))
+        pred = model(input_ids, token_type_ids, input_mask)
+        self.assertTrue(isinstance(pred, dict))
+        self.assertTrue(Const.OUTPUT in pred)
+        self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
+
+    def test_bert_2(self):
+        from fastNLP.core.const import Const
+
+        model = BertForMultipleChoice(2)
+
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        pred = model(input_ids, token_type_ids, input_mask)
+        self.assertTrue(isinstance(pred, dict))
+        self.assertTrue(Const.OUTPUT in pred)
+        self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2))
+
+    def test_bert_3(self):
+        from fastNLP.core.const import Const
+
+        model = BertForTokenClassification(7)
+
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        pred = model(input_ids, token_type_ids, input_mask)
+        self.assertTrue(isinstance(pred, dict))
+        self.assertTrue(Const.OUTPUT in pred)
+        self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7))
+
+    def test_bert_4(self):
+        from fastNLP.core.const import Const
+
+        model = BertForQuestionAnswering()
+
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        pred = model(input_ids, token_type_ids, input_mask)
+        self.assertTrue(isinstance(pred, dict))
+        self.assertTrue(Const.OUTPUTS(0) in pred)
+        self.assertTrue(Const.OUTPUTS(1) in pred)
+        self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 3))
+        self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 3))
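Two reading notes on the updated suite: test_bert_2 feeds a (2, 3) id matrix into BertForMultipleChoice(2) and expects a (1, 2) output, i.e. with num_choices=2 the two rows are evidently treated as the two choices of a single example; and every head now returns a fastNLP-style dict keyed by Const.OUTPUT (or Const.OUTPUTS(i)) instead of raw tensors. To run the suite locally, something along these lines should work from the repository root (the runner is an assumption, not part of the commit, and assumes the test packages are discoverable):

```python
import unittest

# Discover and run every test_bert.py under test/ (covers both files in this commit).
suite = unittest.defaultTestLoader.discover("test", pattern="test_bert.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```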
test/modules/encoder/test_bert.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+
+import unittest
+
+import torch
+
+from fastNLP.models.bert import BertModel
+
+
+class TestBert(unittest.TestCase):
+    def test_bert_1(self):
+        model = BertModel(vocab_size=32000, hidden_size=768,
+                          num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
+        for layer in all_encoder_layers:
+            self.assertEqual(tuple(layer.shape), (2, 3, 768))
+        self.assertEqual(tuple(pooled_output.shape), (2, 768))