Allow missing learning rate variable in older checkpoints

Reuben Morais 2020-02-18 14:42:09 +01:00
parent 17ddc5600e
commit 6e12b7caed
1 changed file with 16 additions and 7 deletions
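
The fix below relies on listing the variable names stored in a checkpoint before deciding what to load. As a rough standalone sketch of that mechanism (the import line and the checkpoint path are illustrative assumptions, not taken from this repository's setup code):

    import tensorflow.compat.v1 as tfv1

    # Hypothetical checkpoint prefix, for illustration only.
    checkpoint_path = '/tmp/checkpoints/train-1234'

    # load_checkpoint returns a CheckpointReader; its variable-to-shape map
    # lists every tensor name stored in the checkpoint.
    reader = tfv1.train.load_checkpoint(checkpoint_path)
    vars_in_ckpt = frozenset(reader.get_variable_to_shape_map().keys())

    # Older checkpoints may simply not contain a learning rate variable.
    print('learning_rate' in vars_in_ckpt)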

@@ -11,29 +11,38 @@ def _load_checkpoint(session, checkpoint_path):
     # we will exclude variables we do not wish to load and then
     # we will initialize them instead
     ckpt = tfv1.train.load_checkpoint(checkpoint_path)
+    vars_in_ckpt = frozenset(ckpt.get_variable_to_shape_map().keys())
     load_vars = set(tfv1.global_variables())
     init_vars = set()
 
+    # We explicitly allow the learning rate variable to be missing for backwards
+    # compatibility with older checkpoints.
+    if 'learning_rate' not in vars_in_ckpt:
+        lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
+        assert len(lr_var) == 1
+        load_vars -= lr_var
+        init_vars |= lr_var
+
     if FLAGS.load_cudnn:
         # Initialize training from a CuDNN RNN checkpoint
         # Identify the variables which we cannot load, and set them
         # for initialization
+        missing_vars = set()
         for v in load_vars:
-            try:
-                ckpt.get_tensor(v.op.name)
-            except tf.errors.NotFoundError:
-                log_error('CUDNN variable not found: %s' % (v.op.name))
+            if v.op.name not in vars_in_ckpt:
+                log_warn('CUDNN variable not found: %s' % (v.op.name))
+                missing_vars.add(v)
                 init_vars.add(v)
 
         load_vars -= init_vars
 
         # Check that the only missing variables (i.e. those to be initialised)
         # are the Adam moment tensors, if they aren't then we have an issue
-        init_var_names = [v.op.name for v in init_vars]
-        if any('Adam' not in v for v in init_var_names):
+        missing_var_names = [v.op.name for v in missing_vars]
+        if any('Adam' not in v for v in missing_var_names):
             log_error('Tried to load a CuDNN RNN checkpoint but there were '
                       'more missing variables than just the Adam moment '
-                      'tensors. Missing variables: {}'.format(init_var_names))
+                      'tensors. Missing variables: {}'.format(missing_var_names))
             sys.exit(1)
 
     if FLAGS.drop_source_layers > 0:
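
The hunk ends before the code that consumes the load_vars / init_vars split. A minimal sketch of how such a split is typically applied, assigning values that exist in the checkpoint and initializing everything else, is below; the helper name and session handling are assumptions for illustration, not this file's actual implementation:

    import tensorflow.compat.v1 as tfv1

    def restore_or_init(session, checkpoint_path, load_vars, init_vars):
        # Read the checkpoint once; get_tensor() fetches stored values by name.
        reader = tfv1.train.load_checkpoint(checkpoint_path)
        for v in load_vars:
            # Every variable left in load_vars is expected to exist in the
            # checkpoint, so assign its stored value directly.
            session.run(v.assign(reader.get_tensor(v.op.name)))
        # Variables excluded from loading (e.g. a missing learning_rate in an
        # older checkpoint, or Adam moment tensors when loading a CuDNN RNN
        # checkpoint) are re-initialized instead.
        session.run(tfv1.variables_initializer(list(init_vars)))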