Fix lr initialization on reload.

This commit is contained in:
Daniel 2020-08-21 14:26:17 +02:00 committed by Reuben Morais
parent 8965b29e81
commit 93a4de5489
2 changed files with 17 additions and 10 deletions

View File

@ -640,17 +640,23 @@ def train():
break break
# Reduce learning rate on plateau # Reduce learning rate on plateau
# If the learning rate was reduced and there is still no improvement
# wait FLAGS.plateau_epochs before the learning rate is reduced again
if (FLAGS.reduce_lr_on_plateau and if (FLAGS.reduce_lr_on_plateau and
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0): epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
# If the learning rate was reduced and there is still no improvement
# wait FLAGS.plateau_epochs before the learning rate is reduced again # Reload checkpoint that we use the best_dev weights again
reload_best_checkpoint(session)
# Reduce learning rate
session.run(reduce_learning_rate_op) session.run(reduce_learning_rate_op)
current_learning_rate = learning_rate_var.eval() current_learning_rate = learning_rate_var.eval()
log_info('Encountered a plateau, reducing learning rate to {}'.format( log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate)) current_learning_rate))
# Reload checkpoint that we use the best_dev weights again # Overwrite best checkpoint with new learning rate value
reload_best_checkpoint(session) save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
if FLAGS.metrics_files: if FLAGS.metrics_files:
# Read only metrics, not affecting best validation loss tracking # Read only metrics, not affecting best validation loss tracking

View File

@ -6,7 +6,7 @@ from .flags import FLAGS
from .logging import log_info, log_error, log_warn from .logging import log_info, log_error, log_warn
def _load_checkpoint(session, checkpoint_path, allow_drop_layers): def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
# Load the checkpoint and put all variables into loading list # Load the checkpoint and put all variables into loading list
# we will exclude variables we do not wish to load and then # we will exclude variables we do not wish to load and then
# we will initialize them instead # we will initialize them instead
@ -18,7 +18,8 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
# We explicitly allow the learning rate variable to be missing for backwards # We explicitly allow the learning rate variable to be missing for backwards
# compatibility with older checkpoints. # compatibility with older checkpoints.
lr_var = set(v for v in load_vars if v.op.name == 'learning_rate') lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
if lr_var and ('learning_rate' not in vars_in_ckpt or FLAGS.force_initialize_learning_rate): if lr_var and ('learning_rate' not in vars_in_ckpt or
(FLAGS.force_initialize_learning_rate and allow_lr_init)):
assert len(lr_var) <= 1 assert len(lr_var) <= 1
load_vars -= lr_var load_vars -= lr_var
init_vars |= lr_var init_vars |= lr_var
@ -87,14 +88,14 @@ def _initialize_all_variables(session):
session.run(v.initializer) session.run(v.initializer)
def _load_or_init_impl(session, method_order, allow_drop_layers): def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
for method in method_order: for method in method_order:
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint' # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
if method == 'best': if method == 'best':
ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint') ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
if ckpt_path: if ckpt_path:
log_info('Loading best validating checkpoint from {}'.format(ckpt_path)) log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
return _load_checkpoint(session, ckpt_path, allow_drop_layers) return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init)
log_info('Could not find best validating checkpoint.') log_info('Could not find best validating checkpoint.')
# Load most recent checkpoint, saved in checkpoint file 'checkpoint' # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
@ -102,7 +103,7 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
ckpt_path = _checkpoint_path_or_none('checkpoint') ckpt_path = _checkpoint_path_or_none('checkpoint')
if ckpt_path: if ckpt_path:
log_info('Loading most recent checkpoint from {}'.format(ckpt_path)) log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
return _load_checkpoint(session, ckpt_path, allow_drop_layers) return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init)
log_info('Could not find most recent checkpoint.') log_info('Could not find most recent checkpoint.')
# Initialize all variables # Initialize all variables
@ -119,7 +120,7 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
def reload_best_checkpoint(session): def reload_best_checkpoint(session):
_load_or_init_impl(session, ['best'], allow_drop_layers=False) _load_or_init_impl(session, ['best'], allow_drop_layers=False, allow_lr_init=False)
def load_or_init_graph_for_training(session): def load_or_init_graph_for_training(session):