mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 21:28:01 +08:00
防止BERTEmbedding在中文场景下被错误使用
This commit is contained in:
parent
f35a4ae2b6
commit
8fe7a4f191
@ -599,7 +599,8 @@ class Trainer(object):
|
||||
self._model_device = _get_model_device(self.model)
|
||||
self._mode(self.model, is_test=False)
|
||||
self._load_best_model = load_best_model
|
||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
|
||||
# Append microseconds (%f) to the timestamp so that two saves started very close together do not collide
|
||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f'))
|
||||
start_time = time.time()
|
||||
self.logger.info("training epochs started " + self.start_time)
|
||||
self.step = 0
|
||||
|
@ -294,7 +294,10 @@ class _WordBertModel(nn.Module):
|
||||
word = '[PAD]'
|
||||
elif index == vocab.unknown_idx:
|
||||
word = '[UNK]'
|
||||
word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
|
||||
_words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
|
||||
word_pieces = []
|
||||
for w in _words:
|
||||
word_pieces.extend(self.tokenzier.wordpiece_tokenizer.tokenize(w))
|
||||
if len(word_pieces) == 1:
|
||||
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到
|
||||
if index != vocab.unknown_idx and word_pieces[0] == '[UNK]': # 说明这个词不在原始的word里面
|
||||
|
@ -989,7 +989,10 @@ class _WordPieceBertModel(nn.Module):
|
||||
def convert_words_to_word_pieces(words):
|
||||
word_pieces = []
|
||||
for word in words:
|
||||
tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word)
|
||||
_words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
|
||||
tokens = []
|
||||
for word in _words:
|
||||
tokens.extend(self.tokenzier.wordpiece_tokenizer.tokenize(word))
|
||||
word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens)
|
||||
word_pieces.extend(word_piece_ids)
|
||||
if add_cls_sep:
|
||||
|
@ -54,6 +54,24 @@ class TrainerTestGround(unittest.TestCase):
|
||||
"""
|
||||
# 应该正确运行
|
||||
"""
|
||||
|
||||
def test_save_path(self):
|
||||
data_set = prepare_fake_dataset()
|
||||
data_set.set_input("x", flag=True)
|
||||
data_set.set_target("y", flag=True)
|
||||
|
||||
train_set, dev_set = data_set.split(0.3)
|
||||
|
||||
model = NaiveClassifier(2, 1)
|
||||
|
||||
save_path = 'test_save_models'
|
||||
|
||||
trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
|
||||
batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
|
||||
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=save_path,
|
||||
use_tqdm=True, check_code_level=2)
|
||||
trainer.train()
|
||||
|
||||
|
||||
def test_trainer_suggestion1(self):
|
||||
# 检查报错提示能否正确提醒用户。
|
||||
|
Loading…
Reference in New Issue
Block a user