mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 21:28:01 +08:00
1.优化Trainer中对exception的处理;2.修改static_embedding贴合seq2seq
This commit is contained in:
parent
9bed203a35
commit
ed7f7b1cd9
@ -617,6 +617,14 @@ class Trainer(object):
|
||||
elif on_exception == 'raise':
|
||||
raise e
|
||||
|
||||
if self.dev_data is not None and self.best_dev_perf is not None and load_best_model:
|
||||
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
|
||||
load_succeed = self._load_model(self.model, model_name)
|
||||
if load_succeed:
|
||||
self.logger.info("Reloaded the best model.")
|
||||
else:
|
||||
self.logger.info("Fail to reload best model.")
|
||||
finally:
|
||||
if self.dev_data is not None and self.best_dev_perf is not None:
|
||||
self.logger.info(
|
||||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step))
|
||||
@ -624,15 +632,7 @@ class Trainer(object):
|
||||
results['best_eval'] = self.best_dev_perf
|
||||
results['best_epoch'] = self.best_dev_epoch
|
||||
results['best_step'] = self.best_dev_step
|
||||
if load_best_model:
|
||||
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
|
||||
load_succeed = self._load_model(self.model, model_name)
|
||||
if load_succeed:
|
||||
self.logger.info("Reloaded the best model.")
|
||||
else:
|
||||
self.logger.info("Fail to reload best model.")
|
||||
finally:
|
||||
pass
|
||||
|
||||
results['seconds'] = round(time.time() - start_time, 2)
|
||||
|
||||
return results
|
||||
|
@ -76,6 +76,10 @@ class StaticEmbedding(TokenEmbedding):
|
||||
"""
|
||||
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
|
||||
if embedding_dim > 0:
|
||||
if model_dir_or_name is not None:
|
||||
warnings.warn(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with"
|
||||
f" dimension {embedding_dim}. If you want to use pre-trained embedding, "
|
||||
f"set `embedding_dim` to 0.")
|
||||
model_dir_or_name = None
|
||||
|
||||
# 得到cache_path
|
||||
|
Loading…
Reference in New Issue
Block a user