1.优化Trainer中对exception的处理;2.修改static_embedding贴合seq2seq

This commit is contained in:
yh_cc 2020-03-01 10:41:48 +08:00
parent 9bed203a35
commit ed7f7b1cd9
2 changed files with 13 additions and 9 deletions

View File

@ -617,6 +617,14 @@ class Trainer(object):
elif on_exception == 'raise':
raise e
if self.dev_data is not None and self.best_dev_perf is not None and load_best_model:
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
load_succeed = self._load_model(self.model, model_name)
if load_succeed:
self.logger.info("Reloaded the best model.")
else:
self.logger.info("Fail to reload best model.")
finally:
if self.dev_data is not None and self.best_dev_perf is not None:
self.logger.info(
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step))
@ -624,15 +632,7 @@ class Trainer(object):
results['best_eval'] = self.best_dev_perf
results['best_epoch'] = self.best_dev_epoch
results['best_step'] = self.best_dev_step
if load_best_model:
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
load_succeed = self._load_model(self.model, model_name)
if load_succeed:
self.logger.info("Reloaded the best model.")
else:
self.logger.info("Fail to reload best model.")
finally:
pass
results['seconds'] = round(time.time() - start_time, 2)
return results

View File

@ -76,6 +76,10 @@ class StaticEmbedding(TokenEmbedding):
"""
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
if embedding_dim > 0:
if model_dir_or_name is not None:
warnings.warn(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with"
f" dimension {embedding_dim}. If you want to use pre-trained embedding, "
f"set `embedding_dim` to 0.")
model_dir_or_name = None
# 得到cache_path