Fix the save_path bug when dev data is empty

yh_cc 2020-09-19 11:10:36 +08:00
parent 720d7b8729
commit 40bec21684
8 changed files with 30 additions and 14 deletions

View File

@@ -178,7 +178,7 @@ class LossFunc(LossBase):
r"""
A class that allows users to plug in a custom loss function.
:param func: a loss function defined by the user; it should be a function or an object for which callable(func) is True
:param func: a loss function defined by the user; it should be a function
:param dict key_map: a parameter mapping table; keys are Model/DataSet parameter names, values are the loss function's parameter names.
During training, fastNLP's trainer will look in the model's return value or in the target=True fields of the training DataSet
for the corresponding parameter named value, and pass it into func as the parameter named key.
@@ -186,8 +186,8 @@ class LossFunc(LossBase):
Usage::
func = torch.nn.CrossEntropyLoss()
loss_func = LossFunc(func, input="pred", target="label")
import torch.nn.functional as F
loss_func = LossFunc(F.cross_entropy, input="pred", target="label")
# This builds a loss-function object whose loss is computed by func; from the model's return value or the DataSet fields with target=True,
# a parameter named `pred` is found and passed to func as the parameter named `input`, and a parameter named `label` is found
# and passed to func as the parameter named `target`
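For reference, a minimal sketch of the updated usage. It assumes LossFunc is importable from fastNLP.core.losses and that LossBase's __call__(pred_dict, target_dict) interface performs the field mapping; the tensors and field names below are made up:

    import torch
    import torch.nn.functional as F
    from fastNLP.core.losses import LossFunc

    # F.cross_entropy is a plain function, as the updated docstring requires;
    # the keyword arguments map the model-output field `pred` to the function
    # parameter `input` and the target field `label` to `target`.
    loss_func = LossFunc(F.cross_entropy, input="pred", target="label")

    # Hypothetical dicts, shaped like what the Trainer assembles at train time.
    pred_dict = {"pred": torch.randn(4, 3)}             # logits for 3 classes
    target_dict = {"label": torch.tensor([0, 2, 1, 0])}
    loss = loss_func(pred_dict, target_dict)            # scalar loss tensor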

View File

@@ -630,6 +630,11 @@ class Trainer(object):
self.logger.info("Reloaded the best model.")
else:
self.logger.info("Fail to reload best model.")
if self.dev_data is None and self.save_path is not None:
model_name = "_".join([self.model.__class__.__name__, self.start_time])
self._save_model(self.model, model_name)
finally:
if self.dev_data is not None and self.best_dev_perf is not None:
self.logger.info(
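In other words, when a Trainer has no dev_data but does have a save_path, train() now saves the final model under save_path, using the "<ModelClassName>_<start_time>" name built above. A hedged sketch mirroring the new test at the bottom of this commit (train_set, model and save_path are placeholders, and the import paths assume fastNLP's usual top-level exports):

    from fastNLP import Trainer, BCELoss
    from fastNLP.core.optimizer import SGD

    # No validation data and no metrics; before this fix nothing was written
    # to save_path in this configuration, despite save_path being set.
    trainer = Trainer(train_set, model,
                      optimizer=SGD(lr=0.1),
                      loss=BCELoss(pred="predict", target="y"),
                      dev_data=None, metrics=None,
                      save_path=save_path, n_epochs=10, use_tqdm=True)
    trainer.train()   # the final model is now saved into save_path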

View File

@@ -89,7 +89,7 @@ class BertEmbedding(ContextualEmbedding):
word pieces beyond that point, and sets the 512th word piece to [SEP]; the encode result of the part beyond the limit is simply set entirely to zero. Generally, only tasks that use just [CLS]
for classification should set auto_truncate to True.
:param kwargs:
int min_freq: words occurring fewer than this many times are replaced by unk
int min_freq: words occurring fewer than this many times are replaced by unk; defaults to 1
"""
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
@@ -110,7 +110,7 @@ class BertEmbedding(ContextualEmbedding):
if '[CLS]' in vocab:
self._word_cls_index = vocab['CLS']
min_freq = kwargs.get('min_freq', 2)
min_freq = kwargs.get('min_freq', 1)
self._min_freq = min_freq
self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
pool_method=pool_method, include_cls_sep=include_cls_sep,
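The min_freq default is relaxed from 2 to 1 here (and likewise for GPT2Embedding and RobertaEmbedding below), so by default no in-vocabulary word is collapsed to unk inside the word-piece model. A hedged usage sketch; the vocabulary contents and the model name 'en-base-uncased' are illustrative only:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("this is a toy sentence".split())
    # With the new default (min_freq=1) every word in vocab keeps its own
    # word-piece encoding; pass min_freq=2 to restore the old behaviour,
    # where words seen fewer than 2 times are replaced by unk.
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', min_freq=2)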

View File

@@ -83,7 +83,7 @@ class GPT2Embedding(ContextualEmbedding):
only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
truncate_embed = kwargs.get('truncate_embed', True)
min_freq = kwargs.get('min_freq', 2)
min_freq = kwargs.get('min_freq', 1)
self.lm_loss =language_model
self.model = _GPT2Model(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
@@ -315,7 +315,7 @@ class GPT2WordPieceEncoder(nn.Module):
class _GPT2Model(nn.Module):
def __init__(self, model_dir_or_name, vocab, layers, pool_method='first', auto_truncate=True, language_model=False,
only_use_pretrain_bpe=False, min_freq=2, truncate_embed=False):
only_use_pretrain_bpe=False, min_freq=1, truncate_embed=False):
super().__init__()
self.tokenzier = GPT2Tokenizer.from_pretrained(model_dir_or_name)

View File

@@ -78,7 +78,7 @@ class RobertaEmbedding(ContextualEmbedding):
word pieces beyond that point, and sets the 512th word piece to </s>; the encode result of the part beyond the limit is simply set entirely to zero. Generally, only tasks that use just <s>
for classification should set auto_truncate to True.
:param kwargs:
int min_freq: words occurring fewer than this many times are replaced by unk
int min_freq: words occurring fewer than this many times are replaced by unk; defaults to 1
"""
super().__init__(vocab, word_dropout=word_dropout, dropout=dropout)
@@ -93,7 +93,7 @@ class RobertaEmbedding(ContextualEmbedding):
if '<s>' in vocab:
self._word_cls_index = vocab['<s>']
min_freq = kwargs.get('min_freq', 2)
min_freq = kwargs.get('min_freq', 1)
self._min_freq = min_freq
self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
@@ -464,7 +464,7 @@ class RobertaWordPieceEncoder(nn.Module):
os.makedirs(os.path.join(folder, ROBERTA_ENCODER_FOLDER), exist_ok=True)
self.model.save(os.path.join(folder, ROBERTA_ENCODER_FOLDER))
logger.debug(f"BertWordPieceEncoder has been saved in {folder}")
logger.debug(f"RobertaWordPieceEncoder has been saved in {folder}")
@classmethod
def load(cls, folder):

View File

@@ -97,8 +97,8 @@ class StaticEmbedding(TokenEmbedding):
:param int min_freq: words whose frequency in the Vocabulary is lower than this are mapped to unk.
:param dict kwargs:
bool only_train_min_freq: apply the min_freq filter only to words from the train set;
bool only_norm_found_vector: whether to normalize only the words found in the pretrained vectors;
bool only_use_pretrain_word: use only words that appear in the pretrained vocabulary; a word not found there becomes unk. Recommended to set to True if the embedding does not need to be updated.
bool only_norm_found_vector: defaults to False; whether to normalize only the words found in the pretrained vectors;
bool only_use_pretrain_word: defaults to False; use only words that appear in the pretrained vocabulary; a word not found there becomes unk. Recommended to set to True if the embedding does not need to be updated.
"""
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
if embedding_dim > 0:
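A hedged sketch of the kwargs documented above; both flags default to False, and the vocabulary contents and the model name 'en-glove-6b-50d' are illustrative only:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("this is a toy sentence".split())
    # only_norm_found_vector normalizes only the vectors found in the
    # pretrained file; only_use_pretrain_word maps every word missing from
    # the pretrained vocabulary to unk (useful when the embedding is frozen).
    embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d',
                            only_norm_found_vector=True,
                            only_use_pretrain_word=True)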

View File

@@ -308,7 +308,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
max_len_eos_mask = max_lengths.eq(cur_len+1)
eos_scores = scores[:, _eos_token_id]
# if the maximum length has been reached, boost the eos score
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+100, eos_scores)
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)
if do_sample:
if temperature > 0 and temperature != 1:
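A small self-contained illustration of the changed line, with made-up tensors standing in for the generator's internals: hypotheses whose length has hit max_lengths get their eos score raised by 1e12 rather than 100, so eos dwarfs every other token's score and is effectively guaranteed to be chosen at the next step:

    import torch

    scores = torch.randn(4, 10)                  # (num_hypotheses, vocab_size)
    _eos_token_id = 3
    max_len_eos_mask = torch.tensor([True, False, True, False])

    eos_scores = scores[:, _eos_token_id]
    # Only the hypotheses that reached their maximum length are affected;
    # the others keep their original eos score.
    scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores + 1e12, eos_scores)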

View File

@@ -76,8 +76,19 @@ class TrainerTestGround(unittest.TestCase):
use_tqdm=True, check_code_level=2)
trainer.train()
import os
import shutil
self.assertTrue(os.path.exists(save_path))
if os.path.exists(save_path):
shutil.rmtree(save_path)
# training without dev_data
trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=10, print_every=50, dev_data=None,
metrics=None, validate_every=-1, save_path=save_path,
use_tqdm=True, check_code_level=2)
trainer.train()
self.assertTrue(os.path.exists(save_path))
if os.path.exists(save_path):
import shutil
shutil.rmtree(save_path)
def test_trainer_suggestion1(self):