Allow missing learning rate variable in older checkpoints
This commit is contained in:
parent
17ddc5600e
commit
6e12b7caed
|
@ -11,29 +11,38 @@ def _load_checkpoint(session, checkpoint_path):
|
|||
# we will exclude variables we do not wish to load and then
|
||||
# we will initialize them instead
|
||||
ckpt = tfv1.train.load_checkpoint(checkpoint_path)
|
||||
vars_in_ckpt = frozenset(ckpt.get_variable_to_shape_map().keys())
|
||||
load_vars = set(tfv1.global_variables())
|
||||
init_vars = set()
|
||||
|
||||
# We explicitly allow the learning rate variable to be missing for backwards
|
||||
# compatibility with older checkpoints.
|
||||
if 'learning_rate' not in vars_in_ckpt:
|
||||
lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
|
||||
assert len(lr_var) == 1
|
||||
load_vars -= lr_var
|
||||
init_vars |= lr_var
|
||||
|
||||
if FLAGS.load_cudnn:
|
||||
# Initialize training from a CuDNN RNN checkpoint
|
||||
# Identify the variables which we cannot load, and set them
|
||||
# for initialization
|
||||
missing_vars = set()
|
||||
for v in load_vars:
|
||||
try:
|
||||
ckpt.get_tensor(v.op.name)
|
||||
except tf.errors.NotFoundError:
|
||||
log_error('CUDNN variable not found: %s' % (v.op.name))
|
||||
if v.op.name not in vars_in_ckpt:
|
||||
log_warn('CUDNN variable not found: %s' % (v.op.name))
|
||||
missing_vars.add(v)
|
||||
init_vars.add(v)
|
||||
|
||||
load_vars -= init_vars
|
||||
|
||||
# Check that the only missing variables (i.e. those to be initialised)
|
||||
# are the Adam moment tensors, if they aren't then we have an issue
|
||||
init_var_names = [v.op.name for v in init_vars]
|
||||
if any('Adam' not in v for v in init_var_names):
|
||||
missing_var_names = [v.op.name for v in missing_vars]
|
||||
if any('Adam' not in v for v in missing_var_names):
|
||||
log_error('Tried to load a CuDNN RNN checkpoint but there were '
|
||||
'more missing variables than just the Adam moment '
|
||||
'tensors. Missing variables: {}'.format(init_var_names))
|
||||
'tensors. Missing variables: {}'.format(missing_var_names))
|
||||
sys.exit(1)
|
||||
|
||||
if FLAGS.drop_source_layers > 0:
|
||||
|
|
Loading…
Reference in New Issue