Allow missing learning rate variable in older checkpoints

Reuben Morais 2020-02-18 14:42:09 +01:00
parent 17ddc5600e
commit 6e12b7caed
1 changed file with 16 additions and 7 deletions

@@ -11,29 +11,38 @@ def _load_checkpoint(session, checkpoint_path):
     # we will exclude variables we do not wish to load and then
     # we will initialize them instead
     ckpt = tfv1.train.load_checkpoint(checkpoint_path)
+    vars_in_ckpt = frozenset(ckpt.get_variable_to_shape_map().keys())
     load_vars = set(tfv1.global_variables())
     init_vars = set()
 
+    # We explicitly allow the learning rate variable to be missing for backwards
+    # compatibility with older checkpoints.
+    if 'learning_rate' not in vars_in_ckpt:
+        lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
+        assert len(lr_var) == 1
+        load_vars -= lr_var
+        init_vars |= lr_var
+
     if FLAGS.load_cudnn:
         # Initialize training from a CuDNN RNN checkpoint
         # Identify the variables which we cannot load, and set them
         # for initialization
+        missing_vars = set()
         for v in load_vars:
-            try:
-                ckpt.get_tensor(v.op.name)
-            except tf.errors.NotFoundError:
-                log_error('CUDNN variable not found: %s' % (v.op.name))
+            if v.op.name not in vars_in_ckpt:
+                log_warn('CUDNN variable not found: %s' % (v.op.name))
+                missing_vars.add(v)
                 init_vars.add(v)
 
         load_vars -= init_vars
 
         # Check that the only missing variables (i.e. those to be initialised)
         # are the Adam moment tensors, if they aren't then we have an issue
-        init_var_names = [v.op.name for v in init_vars]
-        if any('Adam' not in v for v in init_var_names):
+        missing_var_names = [v.op.name for v in missing_vars]
+        if any('Adam' not in v for v in missing_var_names):
             log_error('Tried to load a CuDNN RNN checkpoint but there were '
                       'more missing variables than just the Adam moment '
-                      'tensors. Missing variables: {}'.format(init_var_names))
+                      'tensors. Missing variables: {}'.format(missing_var_names))
             sys.exit(1)
 
     if FLAGS.drop_source_layers > 0:
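
For context on the pattern this commit adopts: instead of probing the checkpoint for each variable with ckpt.get_tensor() inside a try/except, the new code reads the checkpoint's variable list once via get_variable_to_shape_map() and uses set-membership checks. A minimal standalone sketch of that pattern follows; the helper name split_loadable_vars and its setup are illustrative, not part of the commit.

    # Sketch of the membership-check pattern, assuming TF1-style graph code.
    # split_loadable_vars is a hypothetical helper, not from the commit itself.
    import tensorflow.compat.v1 as tfv1

    def split_loadable_vars(checkpoint_path):
        # The CheckpointReader exposes every stored variable and its shape,
        # so missing variables can be detected without catching NotFoundError.
        ckpt = tfv1.train.load_checkpoint(checkpoint_path)
        vars_in_ckpt = frozenset(ckpt.get_variable_to_shape_map().keys())
        load_vars, init_vars = set(), set()
        for v in tfv1.global_variables():
            if v.op.name in vars_in_ckpt:
                load_vars.add(v)   # restorable from the checkpoint
            else:
                init_vars.add(v)   # e.g. 'learning_rate' in older checkpoints
        return load_vars, init_vars

From there, the loadable set can be restored with, e.g., tfv1.train.Saver(list(load_vars)).restore(session, checkpoint_path), while the missing set is freshly initialized with session.run(tfv1.variables_initializer(list(init_vars))), which is how a missing learning rate variable can be tolerated rather than treated as a fatal error.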