Add more helpful error messages when restoring from checkpoint fails.

PiperOrigin-RevId: 202668227
This commit is contained in:
Karmel Allison 2018-06-29 10:30:13 -07:00 committed by TensorFlower Gardener
parent f139bc3f5c
commit 0979821324
4 changed files with 90 additions and 43 deletions

View File

@ -33,7 +33,6 @@ from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
@ -242,7 +241,7 @@ class SingleEvaluationTest(test.TestCase):
checkpoint_path = os.path.join(self.get_temp_dir(),
'this_file_doesnt_exist')
log_dir = os.path.join(self.get_temp_dir(), 'error_raised')
with self.assertRaises(errors.NotFoundError):
with self.assertRaises(ValueError):
evaluation.evaluate_once('', checkpoint_path, log_dir)
def _prepareCheckpoint(self, checkpoint_path):

View File

@ -38,6 +38,7 @@ from tensorflow.python.estimator.export import export_output
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
@ -1296,6 +1297,31 @@ class EstimatorEvaluateTest(test.TestCase):
dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint())
self.assertEqual(5, scores['global_step'])
def test_wrong_shape_throws_reasonable_error(self):
"""Make sure we are helpful when model_fns change. See b/110263146."""
def _get_model_fn(val=1):
def _model_fn(features, labels, mode):
del features, labels # unused
variables.Variable(val, name='weight')
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=constant_op.constant([[1.]]),
loss=constant_op.constant(0.),
train_op=state_ops.assign_add(training.get_global_step(), 1))
return _model_fn
model_fn_1 = _get_model_fn()
model_fn_2 = _get_model_fn(val=[1])
est1 = estimator.Estimator(model_fn=model_fn_1)
est1.train(dummy_input_fn, steps=5)
est2 = estimator.Estimator(
model_fn=model_fn_2, model_dir=est1.model_dir)
expected_msg = 'Restoring from checkpoint failed.*a mismatch between'
with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg):
est2.train(dummy_input_fn, steps=1,)
def test_scaffold_is_used(self):
def _model_fn_scaffold(features, labels, mode):

View File

@ -22,7 +22,6 @@ from __future__ import print_function
import collections
import os.path
import re
import sys
import time
import uuid
@ -1043,8 +1042,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
ckpt = CheckpointState()
text_format.Merge(file_content, ckpt)
if not ckpt.model_checkpoint_path:
raise ValueError("Invalid checkpoint state loaded from %s",
checkpoint_dir)
raise ValueError("Invalid checkpoint state loaded from "
+ checkpoint_dir)
# For relative model_checkpoint_path and all_model_checkpoint_paths,
# prepend checkpoint_dir.
if not os.path.isabs(ckpt.model_checkpoint_path):
@ -1706,12 +1705,17 @@ class Saver(object):
save_path: Path where parameters were previously saved.
Raises:
ValueError: If save_path is None.
ValueError: If save_path is None or not a valid checkpoint.
"""
if self._is_empty:
return
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
if not checkpoint_exists(compat.as_text(save_path)):
raise ValueError("The passed save_path is not a valid checkpoint: "
+ compat.as_text(save_path))
logging.info("Restoring parameters from %s", compat.as_text(save_path))
try:
if context.executing_eagerly():
@ -1719,23 +1723,24 @@ class Saver(object):
else:
sess.run(self.saver_def.restore_op_name,
{self.saver_def.filename_tensor_name: save_path})
except errors.NotFoundError:
exception_type, exception_value, exception_traceback = sys.exc_info()
# The checkpoint would not be loaded successfully as is. Try to parse it
# as an object-based checkpoint.
should_reraise = False
except errors.NotFoundError as err:
# There are three common conditions that might cause this error:
# 0. The file is missing. We ignore here, as this is checked above.
# 1. This is an object-based checkpoint trying name-based loading.
# 2. The graph has been altered and a variable or other name is missing.
# 1. The checkpoint would not be loaded successfully as is. Try to parse
# it as an object-based checkpoint.
try:
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
object_graph_string = reader.get_tensor(
checkpointable.OBJECT_GRAPH_PROTO_KEY)
except errors.NotFoundError:
# This is not an object-based checkpoint, or the checkpoint doesn't
# exist. Re-raise the original exception, but do it outside the except
# block so the object graph lookup isn't included in the stack trace.
should_reraise = True
if should_reraise:
six.reraise(exception_type, exception_value, exception_traceback)
del exception_traceback # avoid reference cycles
# 2. This is not an object-based checkpoint, which likely means there
# is a graph mismatch. Re-raise the original error with
# a helpful message (b/110263146)
raise _wrap_restore_error_with_msg(
err, "a Variable name or other graph key that is missing")
# This is an object-based checkpoint. We'll print a warning and then do
# the restore.
@ -1747,6 +1752,11 @@ class Saver(object):
self._restore_from_object_based_checkpoint(
sess=sess, save_path=save_path,
object_graph_string=object_graph_string)
except errors.InvalidArgumentError as err:
# There is a mismatch between the graph and the checkpoint being loaded.
# We add a more reasonable error message here to help users (b/110263146)
raise _wrap_restore_error_with_msg(
err, "a mismatch between the current graph and the graph")
def _restore_from_object_based_checkpoint(self, sess, save_path,
object_graph_string):
@ -2139,6 +2149,14 @@ def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
return meta_graph_filename
def _wrap_restore_error_with_msg(err, extra_verbiage):
err_msg = ("Restoring from checkpoint failed. This is most likely "
"due to {} from the checkpoint. Please ensure that you "
"have not altered the graph expected based on the checkpoint. "
"Original error:\n\n{}").format(extra_verbiage, err.message)
return err.__class__(err.node_def, err.op, err_msg)
ops.register_proto_function(
ops.GraphKeys.SAVERS,
proto_type=saver_pb2.SaverDef,

View File

@ -24,10 +24,8 @@ import math
import os
import random
import shutil
import sys
import tempfile
import time
import traceback
import numpy as np
import six
@ -369,8 +367,8 @@ class SaverTest(test.TestCase):
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
with self.test_session() as sess:
save = saver_module.Saver({"v0": v0}, write_version=ver)
with self.assertRaisesRegexp(errors.NotFoundError,
"Failed to find any matching files for"):
with self.assertRaisesRegexp(
ValueError, "The passed save_path is not a valid checkpoint:"):
save.restore(sess, "invalid path")
def testInt64(self):
@ -3139,27 +3137,33 @@ class CheckpointableCompatibilityTests(test.TestCase):
errors.NotFoundError, "Key b not found in checkpoint"):
b_saver.restore(sess=sess, save_path=save_path)
def testCheckpointNotFoundErrorRaised(self):
# Restore does some tricky exception handling to figure out if it should
# load an object-based checkpoint. Tests that the exception handling isn't
# too broad.
a = resource_variable_ops.ResourceVariable(1., name="a")
saver = saver_module.Saver([a])
with self.test_session() as sess:
with self.assertRaisesRegexp(
errors.NotFoundError,
"Failed to find any matching files for path_which_does_not_exist"):
saver.restore(sess=sess, save_path="path_which_does_not_exist")
try:
saver.restore(sess=sess, save_path="path_which_does_not_exist")
except errors.NotFoundError:
# Make sure we don't have a confusing "During handling of the above
# exception" block in Python 3.
# pylint: disable=no-value-for-parameter
exception_string = "\n".join(
traceback.format_exception(*sys.exc_info()))
# pylint: enable=no-value-for-parameter
self.assertNotIn("NewCheckpointReader", exception_string)
with self.assertRaises(errors.NotFoundError) as cs:
b_saver.restore(sess=sess, save_path=save_path)
# Make sure we don't have a confusing "During handling of the above
# exception" block in Python 3.
self.assertNotIn("NewCheckpointReader", cs.exception.message)
def testGraphChangedForRestoreErrorRaised(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
with ops_lib.Graph().as_default() as g:
a = variables.Variable(1., name="a")
a_saver = saver_module.Saver([a])
with self.test_session(graph=g) as sess:
sess.run(a.initializer)
save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
with ops_lib.Graph().as_default() as g:
a = variables.Variable([1.], name="a")
a_saver = saver_module.Saver([a])
with self.test_session(graph=g) as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"a mismatch between the current graph and the graph"):
a_saver.restore(sess=sess, save_path=save_path)
def testLoadFromObjectBasedGraph(self):
checkpoint_directory = self.get_temp_dir()