防止BERTEmbedding在中文场景下被错误使用

This commit is contained in:
yh_cc 2020-03-19 14:22:15 +08:00
parent f35a4ae2b6
commit 8fe7a4f191
4 changed files with 28 additions and 3 deletions

View File

@ -599,7 +599,8 @@ class Trainer(object):
self._model_device = _get_model_device(self.model)
self._mode(self.model, is_test=False)
self._load_best_model = load_best_model
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
# 加上millsecond防止两个太接近的保存
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f'))
start_time = time.time()
self.logger.info("training epochs started " + self.start_time)
self.step = 0

View File

@ -294,7 +294,10 @@ class _WordBertModel(nn.Module):
word = '[PAD]'
elif index == vocab.unknown_idx:
word = '[UNK]'
word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
_words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
word_pieces = []
for w in _words:
word_pieces.extend(self.tokenzier.wordpiece_tokenizer.tokenize(w))
if len(word_pieces) == 1:
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到
if index != vocab.unknown_idx and word_pieces[0] == '[UNK]': # 说明这个词不在原始的word里面

View File

@ -989,7 +989,10 @@ class _WordPieceBertModel(nn.Module):
def convert_words_to_word_pieces(words):
word_pieces = []
for word in words:
tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word)
_words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
tokens = []
for word in _words:
tokens.extend(self.tokenzier.wordpiece_tokenizer.tokenize(word))
word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens)
word_pieces.extend(word_piece_ids)
if add_cls_sep:

View File

@ -54,6 +54,24 @@ class TrainerTestGround(unittest.TestCase):
"""
# 应该正确运行
"""
def test_save_path(self):
data_set = prepare_fake_dataset()
data_set.set_input("x", flag=True)
data_set.set_target("y", flag=True)
train_set, dev_set = data_set.split(0.3)
model = NaiveClassifier(2, 1)
save_path = 'test_save_models'
trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=save_path,
use_tqdm=True, check_code_level=2)
trainer.train()
def test_trainer_suggestion1(self):
# 检查报错提示能否正确提醒用户。