Fix lr initialization on reload.
commit 93a4de5489
parent 8965b29e81
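The fix has two coordinated parts, both visible in the hunks below. In the train() plateau handler, the best_dev weights are now reloaded *before* the learning rate is reduced, and the best checkpoint is then overwritten so that the reduced rate is what any later reload restores. In the checkpoint-loading helpers, _load_checkpoint and _load_or_init_impl gain an allow_lr_init parameter (default True), and reload_best_checkpoint passes allow_lr_init=False so that --force_initialize_learning_rate cannot reset the freshly reduced rate during a mid-training reload.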
@@ -640,17 +640,23 @@ def train():
                             break
 
                 # Reduce learning rate on plateau
-                if (FLAGS.reduce_lr_on_plateau and
-                        epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
-                    # If the learning rate was reduced and there is still no improvement
-                    # wait FLAGS.plateau_epochs before the learning rate is reduced again
+                # If the learning rate was reduced and there is still no improvement
+                # wait FLAGS.plateau_epochs before the learning rate is reduced again
+                if (FLAGS.reduce_lr_on_plateau and
+                        epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
+                    # Reload checkpoint that we use the best_dev weights again
+                    reload_best_checkpoint(session)
+
+                    # Reduce learning rate
                     session.run(reduce_learning_rate_op)
                     current_learning_rate = learning_rate_var.eval()
                     log_info('Encountered a plateau, reducing learning rate to {}'.format(
                         current_learning_rate))
 
-                    # Reload checkpoint that we use the best_dev weights again
-                    reload_best_checkpoint(session)
+                    # Overwrite best checkpoint with new learning rate value
+                    save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
+                    log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
+
 
             if FLAGS.metrics_files:
                 # Read only metrics, not affecting best validation loss tracking
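The re-save matters because the learning rate lives in a TensorFlow variable and is therefore captured by checkpoints: with the old ordering (reduce first, reload after), restoring the best checkpoint silently rolled the rate back to its pre-reduction value. A minimal standalone sketch of that pitfall, not this repo's code, assuming the TF1-style sessions this codebase uses; the checkpoint path and the 0.1 factor are illustrative:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# The learning rate is a graph variable, so it is saved and restored
# together with the model weights.
learning_rate_var = tf.Variable(0.001, name='learning_rate', trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(learning_rate_var * 0.1)

saver = tf.train.Saver()
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    saver.save(session, '/tmp/best_dev')      # best checkpoint holds lr=0.001
    session.run(reduce_learning_rate_op)      # plateau: lr -> 0.0001
    saver.restore(session, '/tmp/best_dev')   # reload best weights...
    print(session.run(learning_rate_var))     # ...prints 0.001: reduction lost

Reloading first and then reducing and re-saving, as the hunk above does, avoids exactly this rollback.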
@@ -6,7 +6,7 @@ from .flags import FLAGS
 from .logging import log_info, log_error, log_warn
 
 
-def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
+def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
     # Load the checkpoint and put all variables into loading list
     # we will exclude variables we do not wish to load and then
     # we will initialize them instead
@@ -18,7 +18,8 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
     # We explicitly allow the learning rate variable to be missing for backwards
     # compatibility with older checkpoints.
     lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
-    if lr_var and ('learning_rate' not in vars_in_ckpt or FLAGS.force_initialize_learning_rate):
+    if lr_var and ('learning_rate' not in vars_in_ckpt or
+                   (FLAGS.force_initialize_learning_rate and allow_lr_init)):
         assert len(lr_var) <= 1
         load_vars -= lr_var
         init_vars |= lr_var
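With the extra condition, the learning rate variable is only force-initialized when the user asked for it *and* the caller permits it; the backwards-compatibility path (checkpoint predates the variable) is unchanged. A hedged sketch of just that decision, with the surrounding state stubbed out and a hypothetical function name:

def lr_vars_to_init(load_vars, vars_in_ckpt, force_init, allow_lr_init):
    # Return the learning-rate variables to initialize rather than load:
    # either the checkpoint has no 'learning_rate' entry (older checkpoint),
    # or force-init was requested AND the caller allows it.
    lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
    if lr_var and ('learning_rate' not in vars_in_ckpt or
                   (force_init and allow_lr_init)):
        assert len(lr_var) <= 1
        return lr_var
    return set()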
@@ -87,14 +88,14 @@ def _initialize_all_variables(session):
         session.run(v.initializer)
 
 
-def _load_or_init_impl(session, method_order, allow_drop_layers):
+def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
     for method in method_order:
         # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
         if method == 'best':
             ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
             if ckpt_path:
                 log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
-                return _load_checkpoint(session, ckpt_path, allow_drop_layers)
+                return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init)
             log_info('Could not find best validating checkpoint.')
 
         # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
@@ -102,7 +103,7 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
             ckpt_path = _checkpoint_path_or_none('checkpoint')
             if ckpt_path:
                 log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
-                return _load_checkpoint(session, ckpt_path, allow_drop_layers)
+                return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init)
             log_info('Could not find most recent checkpoint.')
 
         # Initialize all variables
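Both lookup paths ('best' and most recent) forward allow_lr_init into _load_checkpoint. Because the parameter defaults to True at every level, callers that do not pass it keep the previous behavior; only call sites that opt out explicitly, like the one below, are affected.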
@@ -119,7 +120,7 @@ def _load_or_init_impl(session, method_order, allow_drop_layers):
 
 
 def reload_best_checkpoint(session):
-    _load_or_init_impl(session, ['best'], allow_drop_layers=False)
+    _load_or_init_impl(session, ['best'], allow_drop_layers=False, allow_lr_init=False)
 
 
 def load_or_init_graph_for_training(session):
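The net effect on the two call paths, sketched below; the ['best'] method order and allow_lr_init=False come from this diff, while the training-start call is abbreviated and its arguments are illustrative:

# Training start: allow_lr_init keeps its default of True, so
# --force_initialize_learning_rate may re-initialize the rate here.
_load_or_init_impl(session, method_order, allow_drop_layers=True)

# Plateau reload mid-training: allow_lr_init=False, so the rate just
# reduced by reduce_learning_rate_op survives the reload of best_dev.
_load_or_init_impl(session, ['best'], allow_drop_layers=False, allow_lr_init=False)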