From 0de9e4bf80cf2f9cc7c848b9189811aa32a570aa Mon Sep 17 00:00:00 2001 From: Richard Hamnett Date: Fri, 21 Feb 2020 18:32:03 +0000 Subject: [PATCH 1/3] Add force_initialize_learning_rate Ability to reset learning rate which has been reduced by reduce_lr_on_plateau --- util/flags.py | 1 + 1 file changed, 1 insertion(+) diff --git a/util/flags.py b/util/flags.py index f46fdc81..a7f9a43f 100644 --- a/util/flags.py +++ b/util/flags.py @@ -148,6 +148,7 @@ def create_flags(): f.DEFINE_boolean('reduce_lr_on_plateau', False, 'Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.') f.DEFINE_integer('plateau_epochs', 10, 'Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping') f.DEFINE_float('plateau_reduction', 0.1, 'Multiplicative factor to apply to the current learning rate if a plateau has occurred.') + f.DEFINE_float('force_initialize_learning_rate', False, 'Force re-initialization of learning rate which was previously reduced.') # Decoder From 5e1f54ae4fc70b05f188d04387b9ce288a80e230 Mon Sep 17 00:00:00 2001 From: Richard Hamnett Date: Fri, 21 Feb 2020 18:33:43 +0000 Subject: [PATCH 2/3] Reset learning rate if force set --- util/checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/checkpoints.py b/util/checkpoints.py index e2b3dad8..9139dbc9 100644 --- a/util/checkpoints.py +++ b/util/checkpoints.py @@ -18,7 +18,7 @@ def _load_checkpoint(session, checkpoint_path): # We explicitly allow the learning rate variable to be missing for backwards # compatibility with older checkpoints. lr_var = set(v for v in load_vars if v.op.name == 'learning_rate') - if lr_var and 'learning_rate' not in vars_in_ckpt: + if lr_var and ('learning_rate' not in vars_in_ckpt or FLAGS.force_initialize_learning_rate): assert len(lr_var) <= 1 load_vars -= lr_var init_vars |= lr_var From a3268545ab0e81249848e27252e18e6c377b1e54 Mon Sep 17 00:00:00 2001 From: Richard Hamnett Date: Fri, 21 Feb 2020 19:24:13 +0000 Subject: [PATCH 3/3] Update flags.py change flag datatype to boolean --- util/flags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/flags.py b/util/flags.py index a7f9a43f..5057d76c 100644 --- a/util/flags.py +++ b/util/flags.py @@ -148,7 +148,7 @@ def create_flags(): f.DEFINE_boolean('reduce_lr_on_plateau', False, 'Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.') f.DEFINE_integer('plateau_epochs', 10, 'Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping') f.DEFINE_float('plateau_reduction', 0.1, 'Multiplicative factor to apply to the current learning rate if a plateau has occurred.') - f.DEFINE_float('force_initialize_learning_rate', False, 'Force re-initialization of learning rate which was previously reduced.') + f.DEFINE_boolean('force_initialize_learning_rate', False, 'Force re-initialization of learning rate which was previously reduced.') # Decoder