diff --git a/util/checkpoints.py b/util/checkpoints.py index 94b1273e..e2b3dad8 100644 --- a/util/checkpoints.py +++ b/util/checkpoints.py @@ -17,9 +17,9 @@ def _load_checkpoint(session, checkpoint_path): # We explicitly allow the learning rate variable to be missing for backwards # compatibility with older checkpoints. - if 'learning_rate' not in vars_in_ckpt: - lr_var = set(v for v in load_vars if v.op.name == 'learning_rate') - assert len(lr_var) == 1 + lr_var = set(v for v in load_vars if v.op.name == 'learning_rate') + if lr_var and 'learning_rate' not in vars_in_ckpt: + assert len(lr_var) <= 1 load_vars -= lr_var init_vars |= lr_var