Merge pull request #3286 from mozilla/test-pr-3268

Test PR #3268
commit 1c9f3bc99d
Reuben Morais, 2020-08-27 20:02:48 +02:00 (committed by GitHub)
2 changed files with 17 additions and 10 deletions

@@ -640,17 +640,23 @@ def train():
                             break
 
                     # Reduce learning rate on plateau
+                    # If the learning rate was reduced and there is still no improvement
+                    # wait FLAGS.plateau_epochs before the learning rate is reduced again
                     if (FLAGS.reduce_lr_on_plateau and
                             epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
-                        # If the learning rate was reduced and there is still no improvement
-                        # wait FLAGS.plateau_epochs before the learning rate is reduced again
+                        # Reload checkpoint that we use the best_dev weights again
+                        reload_best_checkpoint(session)
+
+                        # Reduce learning rate
                         session.run(reduce_learning_rate_op)
                         current_learning_rate = learning_rate_var.eval()
                         log_info('Encountered a plateau, reducing learning rate to {}'.format(
                             current_learning_rate))
-                        # Reload checkpoint that we use the best_dev weights again
-                        reload_best_checkpoint(session)
+
+                        # Overwrite best checkpoint with new learning rate value
+                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
+                        log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
 
                 if FLAGS.metrics_files:
                     # Read only metrics, not affecting best validation loss tracking
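The ordering in this hunk is the substance of the fix: restore the best_dev weights first, then reduce the learning rate, then re-save the best checkpoint so the reduced rate, and not the stale one, is what a later reload picks up. Below is a minimal sketch of that ordering in plain Python; Trainer, save_best, restore_best and the two constants are hypothetical stand-ins for the TF1 session, best_dev_saver and FLAGS used in the real loop, not the project's API.

# Hypothetical stand-ins: PLATEAU_EPOCHS for FLAGS.plateau_epochs,
# LR_REDUCTION for the factor applied by reduce_learning_rate_op.
PLATEAU_EPOCHS = 3
LR_REDUCTION = 0.5

class Trainer:
    def __init__(self, lr):
        self.lr = lr
        self.best = None  # (dev loss, lr) stored in the 'best_dev' checkpoint

    def save_best(self, loss):
        # Stand-in for best_dev_saver.save(): persists weights and lr together.
        self.best = (loss, self.lr)

    def restore_best(self):
        # Stand-in for reload_best_checkpoint(): weights come back from the
        # checkpoint, but the in-memory learning rate is left untouched
        # (allow_lr_init=False in the real code).
        pass

    def on_plateau_check(self, epochs_without_improvement):
        if (epochs_without_improvement > 0 and
                epochs_without_improvement % PLATEAU_EPOCHS == 0):
            self.restore_best()            # 1. back to best_dev weights
            self.lr *= LR_REDUCTION        # 2. then reduce the rate
            self.save_best(self.best[0])   # 3. persist the reduced rate

trainer = Trainer(lr=0.1)
trainer.save_best(loss=12.3)
trainer.on_plateau_check(epochs_without_improvement=3)
print(trainer.best)  # (12.3, 0.05): best loss kept, reduced rate stored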

@@ -6,7 +6,7 @@ from .flags import FLAGS
 from .logging import log_info, log_error, log_warn
 
 
-def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
+def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
     # Load the checkpoint and put all variables into loading list
     # we will exclude variables we do not wish to load and then
     # we will initialize them instead
@@ -18,7 +18,8 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
     # We explicitly allow the learning rate variable to be missing for backwards
     # compatibility with older checkpoints.
     lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
-    if lr_var and ('learning_rate' not in vars_in_ckpt or FLAGS.force_initialize_learning_rate):
+    if lr_var and ('learning_rate' not in vars_in_ckpt or
+                   (FLAGS.force_initialize_learning_rate and allow_lr_init)):
         assert len(lr_var) <= 1
         load_vars -= lr_var
         init_vars |= lr_var
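Read as a predicate, the amended condition lets --force_initialize_learning_rate act only when the call site permits it. A small self-contained sketch (the function name and arguments are illustrative, not this module's API):

def should_init_lr(in_ckpt, force_init, allow_lr_init=True):
    # Re-initialize if the checkpoint predates the learning_rate variable,
    # or if the user forces it and the call site has not opted out.
    return (not in_ckpt) or (force_init and allow_lr_init)

# Old checkpoints without the variable are always re-initialized:
assert should_init_lr(in_ckpt=False, force_init=False)
# reload_best_checkpoint() opts out, so a forced init no longer clobbers
# the (possibly just-reduced) rate saved in best_dev:
assert not should_init_lr(in_ckpt=True, force_init=True, allow_lr_init=False)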
@@ -87,14 +88,14 @@ def _initialize_all_variables(session):
         session.run(v.initializer)
 
 
-def _load_or_init_impl(session, method_order, allow_drop_layers):
+def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
     for method in method_order:
         # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
         if method == 'best':
             ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
             if ckpt_path:
                 log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
-                return _load_checkpoint(session, ckpt_path, allow_drop_layers)
+                return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init)
             log_info('Could not find best validating checkpoint.')
 
         # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
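_load_or_init_impl tries each entry in method_order and returns at the first checkpoint it can load, so the new allow_lr_init flag has to be forwarded on every load path. A rough sketch of that fallback flow, with hypothetical stubs in place of the real path lookup and loader:

def _load_or_init_sketch(method_order, allow_lr_init=True):
    # Hypothetical stand-ins for _checkpoint_path_or_none / _load_checkpoint.
    available = {'best': None, 'last': '/tmp/ckpts/ckpt-1234'}
    for method in method_order:
        if method == 'init':
            return ('initialized',)
        path = available.get(method)
        if path:
            # The flag rides along unchanged on every load path.
            return ('loaded', path, allow_lr_init)
    raise RuntimeError('All initialization methods failed.')

# 'best' is missing here, so the call falls through to 'last':
print(_load_or_init_sketch(['best', 'last', 'init'], allow_lr_init=False))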
@@ -102,7 +103,7 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
             ckpt_path = _checkpoint_path_or_none('checkpoint')
             if ckpt_path:
                 log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
-                return _load_checkpoint(session, ckpt_path, allow_drop_layers)
+                return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init)
             log_info('Could not find most recent checkpoint.')
 
         # Initialize all variables
@@ -119,7 +120,7 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
 
 
 def reload_best_checkpoint(session):
-    _load_or_init_impl(session, ['best'], allow_drop_layers=False)
+    _load_or_init_impl(session, ['best'], allow_drop_layers=False, allow_lr_init=False)
 
 
 def load_or_init_graph_for_training(session):
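The net effect shows up in the entry points: reload_best_checkpoint() is the one caller that opts out of learning rate initialization, so a plateau reload keeps the rate train() just reduced and re-saved, even when --force_initialize_learning_rate is set. A usage-style sketch with a stub in place of _load_or_init_impl (the method lists and defaults here are illustrative, not the exact ones the real entry points compute):

def _load_or_init_stub(method_order, allow_drop_layers, allow_lr_init=True):
    # Stand-in that just records how it was called.
    return {'methods': method_order,
            'allow_drop_layers': allow_drop_layers,
            'allow_lr_init': allow_lr_init}

def load_or_init_graph_for_training_sketch():
    # Start of training: forced LR re-initialization remains possible.
    return _load_or_init_stub(['best', 'last', 'init'], allow_drop_layers=True)

def reload_best_checkpoint_sketch():
    # Mid-training plateau reload: never clobber the stored learning rate.
    return _load_or_init_stub(['best'], allow_drop_layers=False,
                              allow_lr_init=False)

print(load_or_init_graph_for_training_sketch()['allow_lr_init'])  # True
print(reload_best_checkpoint_sketch()['allow_lr_init'])           # False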