From 93a4de548955acea7ac654517fc7f1ced7491ef9 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Fri, 21 Aug 2020 14:26:17 +0200
Subject: [PATCH] Fix lr initialization on reload.

---
 training/deepspeech_training/train.py            | 14 ++++++++++----
 training/deepspeech_training/util/checkpoints.py | 13 +++++++------
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py
index 155b0693..e8adb119 100644
--- a/training/deepspeech_training/train.py
+++ b/training/deepspeech_training/train.py
@@ -640,17 +640,23 @@ def train():
                         break
 
                     # Reduce learning rate on plateau
+                    # If the learning rate was reduced and there is still no improvement,
+                    # wait FLAGS.plateau_epochs before the learning rate is reduced again
                     if (FLAGS.reduce_lr_on_plateau and
                             epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
-                        # If the learning rate was reduced and there is still no improvement
-                        # wait FLAGS.plateau_epochs before the learning rate is reduced again
+
+                        # Reload checkpoint so that we use the best_dev weights again
+                        reload_best_checkpoint(session)
+
+                        # Reduce learning rate
                         session.run(reduce_learning_rate_op)
                         current_learning_rate = learning_rate_var.eval()
                         log_info('Encountered a plateau, reducing learning rate to {}'.format(
                             current_learning_rate))
 
-                        # Reload checkpoint that we use the best_dev weights again
-                        reload_best_checkpoint(session)
+                        # Overwrite best checkpoint with new learning rate value
+                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
+                        log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
 
                 if FLAGS.metrics_files:
                     # Read only metrics, not affecting best validation loss tracking
diff --git a/training/deepspeech_training/util/checkpoints.py b/training/deepspeech_training/util/checkpoints.py
index 27a3dc1c..459a4d06 100644
--- a/training/deepspeech_training/util/checkpoints.py
+++ b/training/deepspeech_training/util/checkpoints.py
@@ -6,7 +6,7 @@ from .flags import FLAGS
 from .logging import log_info, log_error, log_warn
 
 
-def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
+def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
     # Load the checkpoint and put all variables into loading list
     # we will exclude variables we do not wish to load and then
     # we will initialize them instead
@@ -18,7 +18,8 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
     # We explicitly allow the learning rate variable to be missing for backwards
     # compatibility with older checkpoints.
     lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
-    if lr_var and ('learning_rate' not in vars_in_ckpt or FLAGS.force_initialize_learning_rate):
+    if lr_var and ('learning_rate' not in vars_in_ckpt or
+                   (FLAGS.force_initialize_learning_rate and allow_lr_init)):
         assert len(lr_var) <= 1
         load_vars -= lr_var
         init_vars |= lr_var
@@ -87,14 +88,14 @@ def _initialize_all_variables(session):
             session.run(v.initializer)
 
 
-def _load_or_init_impl(session, method_order, allow_drop_layers):
+def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
     for method in method_order:
         # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
         if method == 'best':
             ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
             if ckpt_path:
                 log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
-                return _load_checkpoint(session, ckpt_path, allow_drop_layers)
+                return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init)
             log_info('Could not find best validating checkpoint.')
 
         # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
@@ -102,7 +103,7 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
             ckpt_path = _checkpoint_path_or_none('checkpoint')
             if ckpt_path:
                 log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
-                return _load_checkpoint(session, ckpt_path, allow_drop_layers)
+                return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init)
             log_info('Could not find most recent checkpoint.')
 
         # Initialize all variables
@@ -119,7 +120,7 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
 
 
 def reload_best_checkpoint(session):
-    _load_or_init_impl(session, ['best'], allow_drop_layers=False)
+    _load_or_init_impl(session, ['best'], allow_drop_layers=False, allow_lr_init=False)
 
 
 def load_or_init_graph_for_training(session):
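
Why the reordering in train.py matters: reload_best_checkpoint() restores every
variable stored in the best_dev checkpoint, including learning_rate. Reducing
the rate first and reloading afterwards (the old order) therefore silently
reset the rate to the value saved in the checkpoint. The patched sequence
reloads first, reduces second, and re-saves the checkpoint so the stored rate
matches the active one. Below is a minimal sketch of the resulting control
flow, not the verbatim DeepSpeech code, assuming the names defined earlier in
train() (best_dev_saver, best_dev_path, reduce_learning_rate_op, global_step):

    # Sketch only: plateau handling after this patch, in evaluation order.
    if (FLAGS.reduce_lr_on_plateau and
            epochs_without_improvement > 0 and
            epochs_without_improvement % FLAGS.plateau_epochs == 0):
        # 1. Restore the best_dev weights (and the learning_rate variable
        #    stored with them) before touching the rate.
        reload_best_checkpoint(session)
        # 2. Only now reduce the rate, so the restore cannot clobber it.
        session.run(reduce_learning_rate_op)
        # 3. Persist the reduced rate back into the best_dev checkpoint,
        #    so the next plateau reload starts from the new value.
        best_dev_saver.save(session, best_dev_path, global_step=global_step,
                            latest_filename='best_dev_checkpoint')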
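
The allow_lr_init parameter in checkpoints.py makes --force_initialize_learning_rate
apply only when training starts, not on the mid-training reloads triggered above:
reload_best_checkpoint() passes allow_lr_init=False, which keeps learning_rate in
the set of variables restored from the checkpoint. A condensed sketch of that
partitioning, assuming load_vars, init_vars and vars_in_ckpt as defined in
_load_checkpoint():

    # learning_rate may be absent from older checkpoints; it is also
    # re-initialized on request, but only when allow_lr_init is True.
    lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
    if lr_var and ('learning_rate' not in vars_in_ckpt or
                   (FLAGS.force_initialize_learning_rate and allow_lr_init)):
        load_vars -= lr_var   # do not read the rate from the checkpoint ...
        init_vars |= lr_var   # ... re-initialize it to its default instead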