commit
1c9f3bc99d
|
@ -640,17 +640,23 @@ def train():
|
|||
break
|
||||
|
||||
# Reduce learning rate on plateau
|
||||
# If the learning rate was reduced and there is still no improvement
|
||||
# wait FLAGS.plateau_epochs before the learning rate is reduced again
|
||||
if (FLAGS.reduce_lr_on_plateau and
|
||||
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
|
||||
# If the learning rate was reduced and there is still no improvement
|
||||
# wait FLAGS.plateau_epochs before the learning rate is reduced again
|
||||
|
||||
# Reload checkpoint that we use the best_dev weights again
|
||||
reload_best_checkpoint(session)
|
||||
|
||||
# Reduce learning rate
|
||||
session.run(reduce_learning_rate_op)
|
||||
current_learning_rate = learning_rate_var.eval()
|
||||
log_info('Encountered a plateau, reducing learning rate to {}'.format(
|
||||
current_learning_rate))
|
||||
|
||||
# Reload checkpoint that we use the best_dev weights again
|
||||
reload_best_checkpoint(session)
|
||||
# Overwrite best checkpoint with new learning rate value
|
||||
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
|
||||
log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
|
||||
|
||||
if FLAGS.metrics_files:
|
||||
# Read only metrics, not affecting best validation loss tracking
|
||||
|
|
|
@ -6,7 +6,7 @@ from .flags import FLAGS
|
|||
from .logging import log_info, log_error, log_warn
|
||||
|
||||
|
||||
def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
|
||||
def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
|
||||
# Load the checkpoint and put all variables into loading list
|
||||
# we will exclude variables we do not wish to load and then
|
||||
# we will initialize them instead
|
||||
|
@ -18,7 +18,8 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
|
|||
# We explicitly allow the learning rate variable to be missing for backwards
|
||||
# compatibility with older checkpoints.
|
||||
lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
|
||||
if lr_var and ('learning_rate' not in vars_in_ckpt or FLAGS.force_initialize_learning_rate):
|
||||
if lr_var and ('learning_rate' not in vars_in_ckpt or
|
||||
(FLAGS.force_initialize_learning_rate and allow_lr_init)):
|
||||
assert len(lr_var) <= 1
|
||||
load_vars -= lr_var
|
||||
init_vars |= lr_var
|
||||
|
@ -87,14 +88,14 @@ def _initialize_all_variables(session):
|
|||
session.run(v.initializer)
|
||||
|
||||
|
||||
def _load_or_init_impl(session, method_order, allow_drop_layers):
|
||||
def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
|
||||
for method in method_order:
|
||||
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
|
||||
if method == 'best':
|
||||
ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
|
||||
if ckpt_path:
|
||||
log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
|
||||
return _load_checkpoint(session, ckpt_path, allow_drop_layers)
|
||||
return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init)
|
||||
log_info('Could not find best validating checkpoint.')
|
||||
|
||||
# Load most recent checkpoint, saved in checkpoint file 'checkpoint'
|
||||
|
@ -102,7 +103,7 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
|
|||
ckpt_path = _checkpoint_path_or_none('checkpoint')
|
||||
if ckpt_path:
|
||||
log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
|
||||
return _load_checkpoint(session, ckpt_path, allow_drop_layers)
|
||||
return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init)
|
||||
log_info('Could not find most recent checkpoint.')
|
||||
|
||||
# Initialize all variables
|
||||
|
@ -119,7 +120,7 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
|
|||
|
||||
|
||||
def reload_best_checkpoint(session):
|
||||
_load_or_init_impl(session, ['best'], allow_drop_layers=False)
|
||||
_load_or_init_impl(session, ['best'], allow_drop_layers=False, allow_lr_init=False)
|
||||
|
||||
|
||||
def load_or_init_graph_for_training(session):
|
||||
|
|
Loading…
Reference in New Issue