Handle graph without learning rate variable for export case

This commit is contained in:
Reuben Morais 2020-02-20 15:40:35 +01:00
parent 234a64c6ea
commit 4291db7309
1 changed files with 3 additions and 3 deletions

View File

@ -17,9 +17,9 @@ def _load_checkpoint(session, checkpoint_path):
# We explicitly allow the learning rate variable to be missing for backwards # We explicitly allow the learning rate variable to be missing for backwards
# compatibility with older checkpoints. # compatibility with older checkpoints.
if 'learning_rate' not in vars_in_ckpt: lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
lr_var = set(v for v in load_vars if v.op.name == 'learning_rate') if lr_var and 'learning_rate' not in vars_in_ckpt:
assert len(lr_var) == 1 assert len(lr_var) <= 1
load_vars -= lr_var load_vars -= lr_var
init_vars |= lr_var init_vars |= lr_var