fix bugs in model/bert.py and add testing codes

This commit is contained in:
xuyige 2019-05-29 14:46:48 +08:00
parent e206cae45c
commit 016f02be3b
3 changed files with 133 additions and 19 deletions

View File

@ -10,6 +10,35 @@ from ..core.const import Const
from ..modules.encoder import BertModel
class BertConfig:
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
class BertForSequenceClassification(BaseModel):
"""BERT model for classification.
This module is composed of the BERT model with a linear layer on top of
@ -44,14 +73,19 @@ class BertForSequenceClassification(BaseModel):
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
model = BertForSequenceClassification(config, num_labels)
model = BertForSequenceClassification(num_labels, config)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels, bert_dir):
def __init__(self, num_labels, config=None, bert_dir=None):
super(BertForSequenceClassification, self).__init__()
self.num_labels = num_labels
self.bert = BertModel.from_pretrained(bert_dir)
if bert_dir is not None:
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
config = BertConfig()
self.bert = BertModel(**config.__dict__)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
@ -106,14 +140,19 @@ class BertForMultipleChoice(BaseModel):
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_choices = 2
model = BertForMultipleChoice(config, num_choices, bert_dir)
model = BertForMultipleChoice(num_choices, config, bert_dir)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_choices, bert_dir):
def __init__(self, num_choices, config=None, bert_dir=None):
super(BertForMultipleChoice, self).__init__()
self.num_choices = num_choices
self.bert = BertModel.from_pretrained(bert_dir)
if bert_dir is not None:
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
config = BertConfig()
self.bert = BertModel(**config.__dict__)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
@ -174,14 +213,19 @@ class BertForTokenClassification(BaseModel):
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
bert_dir = 'your-bert-file-dir'
model = BertForTokenClassification(config, num_labels, bert_dir)
model = BertForTokenClassification(num_labels, config, bert_dir)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels, bert_dir):
def __init__(self, num_labels, config=None, bert_dir=None):
super(BertForTokenClassification, self).__init__()
self.num_labels = num_labels
self.bert = BertModel.from_pretrained(bert_dir)
if bert_dir is not None:
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
config = BertConfig()
self.bert = BertModel(**config.__dict__)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
@ -252,9 +296,14 @@ class BertForQuestionAnswering(BaseModel):
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, bert_dir):
def __init__(self, config=None, bert_dir=None):
super(BertForQuestionAnswering, self).__init__()
self.bert = BertModel.from_pretrained(bert_dir)
if bert_dir is not None:
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
config = BertConfig()
self.bert = BertModel(**config.__dict__)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.qa_outputs = nn.Linear(config.hidden_size, 2)

View File

@ -2,20 +2,64 @@ import unittest
import torch
from fastNLP.models.bert import BertModel
from fastNLP.models.bert import *
class TestBert(unittest.TestCase):
def test_bert_1(self):
# model = BertModel.from_pretrained("/home/zyfeng/data/bert-base-chinese")
model = BertModel(vocab_size=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
from fastNLP.core.const import Const
model = BertForSequenceClassification(2)
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
for layer in all_encoder_layers:
self.assertEqual(tuple(layer.shape), (2, 3, 768))
self.assertEqual(tuple(pooled_output.shape), (2, 768))
pred = model(input_ids, token_type_ids, input_mask)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
def test_bert_2(self):
from fastNLP.core.const import Const
model = BertForMultipleChoice(2)
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
pred = model(input_ids, token_type_ids, input_mask)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2))
def test_bert_3(self):
from fastNLP.core.const import Const
model = BertForTokenClassification(7)
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
pred = model(input_ids, token_type_ids, input_mask)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7))
def test_bert_4(self):
from fastNLP.core.const import Const
model = BertForQuestionAnswering()
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
pred = model(input_ids, token_type_ids, input_mask)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUTS(0) in pred)
self.assertTrue(Const.OUTPUTS(1) in pred)
self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 3))
self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 3))

View File

@ -0,0 +1,21 @@
import unittest
import torch
from fastNLP.models.bert import BertModel
class TestBert(unittest.TestCase):
def test_bert_1(self):
model = BertModel(vocab_size=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
for layer in all_encoder_layers:
self.assertEqual(tuple(layer.shape), (2, 3, 768))
self.assertEqual(tuple(pooled_output.shape), (2, 768))