Allow missing learning rate variable in older checkpoints
parent 17ddc5600e
commit 6e12b7caed
@@ -11,29 +11,38 @@ def _load_checkpoint(session, checkpoint_path):
     # we will exclude variables we do not wish to load and then
     # we will initialize them instead
     ckpt = tfv1.train.load_checkpoint(checkpoint_path)
+    vars_in_ckpt = frozenset(ckpt.get_variable_to_shape_map().keys())
     load_vars = set(tfv1.global_variables())
     init_vars = set()
 
+    # We explicitly allow the learning rate variable to be missing for backwards
+    # compatibility with older checkpoints.
+    if 'learning_rate' not in vars_in_ckpt:
+        lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
+        assert len(lr_var) == 1
+        load_vars -= lr_var
+        init_vars |= lr_var
+
     if FLAGS.load_cudnn:
         # Initialize training from a CuDNN RNN checkpoint
         # Identify the variables which we cannot load, and set them
         # for initialization
+        missing_vars = set()
         for v in load_vars:
-            try:
-                ckpt.get_tensor(v.op.name)
-            except tf.errors.NotFoundError:
-                log_error('CUDNN variable not found: %s' % (v.op.name))
+            if v.op.name not in vars_in_ckpt:
+                log_warn('CUDNN variable not found: %s' % (v.op.name))
+                missing_vars.add(v)
                 init_vars.add(v)
 
         load_vars -= init_vars
 
         # Check that the only missing variables (i.e. those to be initialised)
         # are the Adam moment tensors, if they aren't then we have an issue
-        init_var_names = [v.op.name for v in init_vars]
-        if any('Adam' not in v for v in init_var_names):
+        missing_var_names = [v.op.name for v in missing_vars]
+        if any('Adam' not in v for v in missing_var_names):
             log_error('Tried to load a CuDNN RNN checkpoint but there were '
                       'more missing variables than just the Adam moment '
-                      'tensors. Missing variables: {}'.format(init_var_names))
+                      'tensors. Missing variables: {}'.format(missing_var_names))
             sys.exit(1)
 
     if FLAGS.drop_source_layers > 0:
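For context: the restore path that consumes load_vars and init_vars is outside this hunk. The sketch below shows how such a split is typically applied with the TF1-style session API used above; it is a minimal illustration, not the code in this file, and the checkpoint path and variable names in it are hypothetical.

import tensorflow.compat.v1 as tfv1

tfv1.disable_eager_execution()

# Hypothetical checkpoint written before 'learning_rate' was tracked.
checkpoint_path = 'checkpoints/old/best_dev-12345'

# A graph that now includes a learning rate variable.
lr = tfv1.get_variable('learning_rate', initializer=0.001, trainable=False)
w = tfv1.get_variable('layer_1/weights', shape=[4, 4])

ckpt = tfv1.train.load_checkpoint(checkpoint_path)
vars_in_ckpt = frozenset(ckpt.get_variable_to_shape_map().keys())

load_vars = set(tfv1.global_variables())
# Anything absent from the checkpoint (e.g. 'learning_rate' in older
# checkpoints) is initialized fresh instead of restored.
init_vars = {v for v in load_vars if v.op.name not in vars_in_ckpt}
load_vars -= init_vars

with tfv1.Session() as session:
    # Restore what the checkpoint has; initialize everything it lacks.
    for v in load_vars:
        v.load(ckpt.get_tensor(v.op.name), session=session)
    session.run(tfv1.variables_initializer(list(init_vars)))

Loading tensors individually through the CheckpointReader, rather than through a single tfv1.train.Saver over all globals, is what makes this kind of partial restore possible: a Saver built over every global variable would fail outright on a checkpoint missing any of them.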