mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
1.优化Trainer中对exception的处理;2.修改static_embedding贴合seq2seq
This commit is contained in:
parent
9bed203a35
commit
ed7f7b1cd9
@ -617,6 +617,14 @@ class Trainer(object):
|
|||||||
elif on_exception == 'raise':
|
elif on_exception == 'raise':
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
if self.dev_data is not None and self.best_dev_perf is not None and load_best_model:
|
||||||
|
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
|
||||||
|
load_succeed = self._load_model(self.model, model_name)
|
||||||
|
if load_succeed:
|
||||||
|
self.logger.info("Reloaded the best model.")
|
||||||
|
else:
|
||||||
|
self.logger.info("Fail to reload best model.")
|
||||||
|
finally:
|
||||||
if self.dev_data is not None and self.best_dev_perf is not None:
|
if self.dev_data is not None and self.best_dev_perf is not None:
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step))
|
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step))
|
||||||
@ -624,15 +632,7 @@ class Trainer(object):
|
|||||||
results['best_eval'] = self.best_dev_perf
|
results['best_eval'] = self.best_dev_perf
|
||||||
results['best_epoch'] = self.best_dev_epoch
|
results['best_epoch'] = self.best_dev_epoch
|
||||||
results['best_step'] = self.best_dev_step
|
results['best_step'] = self.best_dev_step
|
||||||
if load_best_model:
|
|
||||||
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
|
|
||||||
load_succeed = self._load_model(self.model, model_name)
|
|
||||||
if load_succeed:
|
|
||||||
self.logger.info("Reloaded the best model.")
|
|
||||||
else:
|
|
||||||
self.logger.info("Fail to reload best model.")
|
|
||||||
finally:
|
|
||||||
pass
|
|
||||||
results['seconds'] = round(time.time() - start_time, 2)
|
results['seconds'] = round(time.time() - start_time, 2)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
@ -76,6 +76,10 @@ class StaticEmbedding(TokenEmbedding):
|
|||||||
"""
|
"""
|
||||||
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
|
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
|
||||||
if embedding_dim > 0:
|
if embedding_dim > 0:
|
||||||
|
if model_dir_or_name is not None:
|
||||||
|
warnings.warn(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with"
|
||||||
|
f" dimension {embedding_dim}. If you want to use pre-trained embedding, "
|
||||||
|
f"set `embedding_dim` to 0.")
|
||||||
model_dir_or_name = None
|
model_dir_or_name = None
|
||||||
|
|
||||||
# 得到cache_path
|
# 得到cache_path
|
||||||
|
Loading…
Reference in New Issue
Block a user