Handle graph without learning rate variable for export case

This commit is contained in:
Reuben Morais 2020-02-20 15:40:35 +01:00
parent 234a64c6ea
commit 4291db7309
1 changed files with 3 additions and 3 deletions

View File

@ -17,9 +17,9 @@ def _load_checkpoint(session, checkpoint_path):
# We explicitly allow the learning rate variable to be missing for backwards
# compatibility with older checkpoints.
if 'learning_rate' not in vars_in_ckpt:
lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
assert len(lr_var) == 1
lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
if lr_var and 'learning_rate' not in vars_in_ckpt:
assert len(lr_var) <= 1
load_vars -= lr_var
init_vars |= lr_var