Fix the save_path bug when dev is empty
This commit is contained in:
parent 720d7b8729
commit 40bec21684
@@ -178,7 +178,7 @@ class LossFunc(LossBase):
     r"""
     A class that lets users plug in a custom loss function.
 
-    :param func: a user-defined loss function; it should be a function or an object for which callable(func) is True
+    :param func: a user-defined loss function; it should be a function.
     :param dict key_map: a parameter mapping table. Keys are Model/DataSet parameter names; values are loss-function parameter names.
         During training, fastNLP's trainer looks in the model's return value or in the target=True fields of the training DataSet
         for the parameter whose name matches each value, and passes it to func as the parameter named by the corresponding key.
@@ -186,8 +186,8 @@ class LossFunc(LossBase):
 
     Usage::
 
-        func = torch.nn.CrossEntropyLoss()
-        loss_func = LossFunc(func, input="pred", target="label")
+        import torch.nn.functional as F
+        loss_func = LossFunc(F.cross_entropy, input="pred", target="label")
         # This builds a loss object whose loss is computed by func: from the model's return value or from the
         # target=True fields of the DataSet, a parameter named `pred` is found and passed to func as the parameter
         # named `input`, and a parameter named `label` is found and passed to func as the parameter named `target`.
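The field-to-argument mapping described in this docstring can be pictured with a small standalone sketch; the helper apply_mapped_loss and the toy tensors below are illustrative only and are not part of fastNLP:

    import torch
    import torch.nn.functional as F

    def apply_mapped_loss(func, mapping, model_output, target_fields):
        # mapping: {func_param: field_name}, e.g. {"input": "pred", "target": "label"}
        merged = {**model_output, **target_fields}
        kwargs = {param: merged[field] for param, field in mapping.items()}
        return func(**kwargs)

    model_output = {"pred": torch.randn(4, 3)}             # stands in for the model's return dict
    target_fields = {"label": torch.tensor([0, 2, 1, 0])}  # stands in for a target=True field
    loss = apply_mapped_loss(F.cross_entropy, {"input": "pred", "target": "label"},
                             model_output, target_fields)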
@@ -630,6 +630,11 @@ class Trainer(object):
                        self.logger.info("Reloaded the best model.")
                    else:
                        self.logger.info("Fail to reload best model.")
+
+            if self.dev_data is None and self.save_path is not None:
+                model_name = "_".join([self.model.__class__.__name__, self.start_time])
+                self._save_model(self.model, model_name)
+
        finally:
            if self.dev_data is not None and self.best_dev_perf is not None:
                self.logger.info(
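What the new branch buys in practice: a training run that has a save_path but no dev set now saves the final model instead of silently discarding it. A minimal usage sketch mirroring the test added further down in this commit; the import path and the train_set/model objects are assumptions, not part of the diff:

    from fastNLP import Trainer, BCELoss, SGD

    # train_set and model are assumed to be an already-built DataSet and a fastNLP-compatible model.
    trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1),
                      loss=BCELoss(pred="predict", target="y"),
                      dev_data=None, metrics=None,   # no validation set
                      save_path="./save")            # the final model is now written here
    trainer.train()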
@@ -89,7 +89,7 @@ class BertEmbedding(ContextualEmbedding):
        content after word-piece tokenization and sets the 512th word piece to [SEP]; the encode results for anything
        beyond this length are simply zeroed out. Usually only tasks that classify with just [CLS] set auto_truncate to True.
    :param kwargs:
-        int min_freq: words occurring fewer times than this are replaced by unk
+        int min_freq: words occurring fewer times than this are replaced by unk; defaults to 1
    """
    super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
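Since the value is read with kwargs.get('min_freq', 1) below, the new default can still be overridden per call. A hedged sketch with a placeholder vocabulary and pretrained-model name:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    vocab = Vocabulary().add_word_lst("this is a toy sentence".split())
    # min_freq now defaults to 1; pass it explicitly to restore the previous behaviour of 2.
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', min_freq=2)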
@@ -110,7 +110,7 @@ class BertEmbedding(ContextualEmbedding):
        if '[CLS]' in vocab:
            self._word_cls_index = vocab['CLS']
 
-        min_freq = kwargs.get('min_freq', 2)
+        min_freq = kwargs.get('min_freq', 1)
        self._min_freq = min_freq
        self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                    pool_method=pool_method, include_cls_sep=include_cls_sep,
@@ -83,7 +83,7 @@ class GPT2Embedding(ContextualEmbedding):
 
        only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
        truncate_embed = kwargs.get('truncate_embed', True)
-        min_freq = kwargs.get('min_freq', 2)
+        min_freq = kwargs.get('min_freq', 1)
 
        self.lm_loss =language_model
        self.model = _GPT2Model(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
@@ -315,7 +315,7 @@ class GPT2WordPieceEncoder(nn.Module):
 
 class _GPT2Model(nn.Module):
    def __init__(self, model_dir_or_name, vocab, layers, pool_method='first', auto_truncate=True, language_model=False,
-                 only_use_pretrain_bpe=False, min_freq=2, truncate_embed=False):
+                 only_use_pretrain_bpe=False, min_freq=1, truncate_embed=False):
        super().__init__()
 
        self.tokenzier = GPT2Tokenizer.from_pretrained(model_dir_or_name)
@@ -78,7 +78,7 @@ class RobertaEmbedding(ContextualEmbedding):
        content after word-piece tokenization and sets the 512th word piece to </s>; the encode results for anything
        beyond this length are simply zeroed out. Usually only tasks that classify with just <s> set auto_truncate to True.
    :param kwargs:
-        int min_freq: words occurring fewer times than this are replaced by unk
+        int min_freq: words occurring fewer times than this are replaced by unk; defaults to 1
    """
    super().__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
@@ -93,7 +93,7 @@ class RobertaEmbedding(ContextualEmbedding):
        if '<s>' in vocab:
            self._word_cls_index = vocab['<s>']
 
-        min_freq = kwargs.get('min_freq', 2)
+        min_freq = kwargs.get('min_freq', 1)
        self._min_freq = min_freq
 
        self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
@@ -464,7 +464,7 @@ class RobertaWordPieceEncoder(nn.Module):
 
        os.makedirs(os.path.join(folder, ROBERTA_ENCODER_FOLDER), exist_ok=True)
        self.model.save(os.path.join(folder, ROBERTA_ENCODER_FOLDER))
-        logger.debug(f"BertWordPieceEncoder has been saved in {folder}")
+        logger.debug(f"RobertaWordPieceEncoder has been saved in {folder}")
 
    @classmethod
    def load(cls, folder):
@@ -97,8 +97,8 @@ class StaticEmbedding(TokenEmbedding):
    :param int min_freq: words whose frequency in the Vocabulary is below this value are mapped to unk.
    :param dict kwargs:
        bool only_train_min_freq: apply the min_freq filter only to words from train;
-        bool only_norm_found_vector: whether to normalize only the words found in the pretrained vectors;
-        bool only_use_pretrain_word: use only words that appear in the pretrained vocabulary; a word not found there becomes unk. Recommended to set this to True if the embedding does not need to be updated.
+        bool only_norm_found_vector: defaults to False; whether to normalize only the words found in the pretrained vectors;
+        bool only_use_pretrain_word: defaults to False; use only words that appear in the pretrained vocabulary; a word not found there becomes unk. Recommended to set this to True if the embedding does not need to be updated.
    """
    super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
    if embedding_dim > 0:
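A hedged usage sketch of the kwargs documented above; the pretrained-embedding name is a placeholder and the values shown are just the documented defaults:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary().add_word_lst("a small example sentence".split())
    embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d',
                            only_norm_found_vector=False,   # default per the updated docstring
                            only_use_pretrain_word=False)   # default per the updated docstring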
@@ -308,7 +308,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
            max_len_eos_mask = max_lengths.eq(cur_len+1)
            eos_scores = scores[:, _eos_token_id]
            # if the maximum length has been reached, bump up the score of eos
-            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+100, eos_scores)
+            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)
 
            if do_sample:
                if temperature > 0 and temperature != 1:
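The larger constant matters because torch.where only rewrites the eos column: the bump has to be big enough that eos is guaranteed to outrank every other candidate when hypotheses that have hit max_length are scored. A small self-contained sketch, outside fastNLP, with illustrative numbers:

    import torch

    scores = torch.tensor([[-0.5, -3.2, -1.0],
                           [-0.1, -2.0, -4.0]])        # per-token scores; column 2 plays the role of eos
    max_len_eos_mask = torch.tensor([True, False])     # row 0 has reached max_length
    eos_scores = scores[:, 2]
    scores[:, 2] = torch.where(max_len_eos_mask, eos_scores + 1e12, eos_scores)
    assert scores[0].argmax().item() == 2              # eos is forced for the finished hypothesis
    assert scores[1].argmax().item() == 0              # the unfinished row is untouched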
@@ -76,8 +76,19 @@ class TrainerTestGround(unittest.TestCase):
                          use_tqdm=True, check_code_level=2)
        trainer.train()
        import os
        import shutil
        self.assertTrue(os.path.exists(save_path))
        if os.path.exists(save_path):
            shutil.rmtree(save_path)
+
+        # training without dev_data
+        trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
+                          batch_size=32, n_epochs=10, print_every=50, dev_data=None,
+                          metrics=None, validate_every=-1, save_path=save_path,
+                          use_tqdm=True, check_code_level=2)
+        trainer.train()
+        self.assertTrue(os.path.exists(save_path))
+        if os.path.exists(save_path):
+            import shutil
+            shutil.rmtree(save_path)
 
    def test_trainer_suggestion1(self):