Merge pull request #2781 from rhamnett/patch-2

Add flag to force reinitialisation of learning rate after lr_plateau
Reuben Morais 2020-02-24 16:08:13 +01:00 committed by GitHub
commit 1f1f5a98e4
2 changed files with 2 additions and 1 deletion


@@ -18,7 +18,7 @@ def _load_checkpoint(session, checkpoint_path):
     # We explicitly allow the learning rate variable to be missing for backwards
     # compatibility with older checkpoints.
     lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
-    if lr_var and 'learning_rate' not in vars_in_ckpt:
+    if lr_var and ('learning_rate' not in vars_in_ckpt or FLAGS.force_initialize_learning_rate):
         assert len(lr_var) <= 1
         load_vars -= lr_var
         init_vars |= lr_var
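
The hunk above widens the condition under which the learning-rate variable is dropped from the restore set and scheduled for re-initialization instead. A minimal sketch of that partitioning logic, using plain Python sets of variable names rather than the actual TensorFlow variables (partition_vars and its arguments are illustrative stand-ins, not part of the codebase):

# Simplified illustration of the checkpoint-loading change above.
def partition_vars(load_var_names, vars_in_ckpt, force_initialize_learning_rate):
    """Return (names to restore from the checkpoint, names to re-initialize)."""
    load_vars = set(load_var_names)
    init_vars = set()

    lr_var = {v for v in load_vars if v == 'learning_rate'}
    # Re-initialize the learning rate if the checkpoint predates the variable,
    # or if the user explicitly asked for a fresh learning rate.
    if lr_var and ('learning_rate' not in vars_in_ckpt or force_initialize_learning_rate):
        load_vars -= lr_var
        init_vars |= lr_var
    return load_vars, init_vars

if __name__ == '__main__':
    restore, reinit = partition_vars(
        ['layer_1/weights', 'learning_rate'],
        vars_in_ckpt={'layer_1/weights', 'learning_rate'},
        force_initialize_learning_rate=True)
    print(restore)  # {'layer_1/weights'}
    print(reinit)   # {'learning_rate'}

With the new flag set, learning_rate lands in the init set even though the checkpoint contains it, so training resumes with the freshly initialized learning rate rather than the value previously reduced by the plateau logic.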


@@ -148,6 +148,7 @@ def create_flags():
     f.DEFINE_boolean('reduce_lr_on_plateau', False, 'Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.')
     f.DEFINE_integer('plateau_epochs', 10, 'Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping')
     f.DEFINE_float('plateau_reduction', 0.1, 'Multiplicative factor to apply to the current learning rate if a plateau has occurred.')
+    f.DEFINE_boolean('force_initialize_learning_rate', False, 'Force re-initialization of learning rate which was previously reduced.')
     # Decoder
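
For completeness, a small self-contained sketch of how the new boolean could be toggled from the command line, assuming `f` in the hunk above is absl.flags; the script and surrounding code here are hypothetical, only the two flag names mirror the diff:

# Hypothetical stand-alone example using absl-py.
import sys
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_boolean('reduce_lr_on_plateau', False,
                     'Enable reducing the learning rate if a plateau is reached.')
flags.DEFINE_boolean('force_initialize_learning_rate', False,
                     'Force re-initialization of learning rate which was previously reduced.')

if __name__ == '__main__':
    # e.g. python example.py --reduce_lr_on_plateau --force_initialize_learning_rate
    FLAGS(sys.argv)
    if FLAGS.force_initialize_learning_rate:
        print('learning rate will be re-initialized instead of restored from the checkpoint')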