From 0979821324cbfce24e3162af9d5c0fea747e5ec0 Mon Sep 17 00:00:00 2001 From: Karmel Allison Date: Fri, 29 Jun 2018 10:30:13 -0700 Subject: [PATCH] Add more helpful error messages when restoring from checkpoint fails. PiperOrigin-RevId: 202668227 --- .../slim/python/slim/evaluation_test.py | 3 +- tensorflow/python/estimator/estimator_test.py | 26 +++++++++ tensorflow/python/training/saver.py | 50 +++++++++++------ tensorflow/python/training/saver_test.py | 54 ++++++++++--------- 4 files changed, 90 insertions(+), 43 deletions(-) diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index 3d0308aaf3d..2c978345234 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -33,7 +33,6 @@ from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics @@ -242,7 +241,7 @@ class SingleEvaluationTest(test.TestCase): checkpoint_path = os.path.join(self.get_temp_dir(), 'this_file_doesnt_exist') log_dir = os.path.join(self.get_temp_dir(), 'error_raised') - with self.assertRaises(errors.NotFoundError): + with self.assertRaises(ValueError): evaluation.evaluate_once('', checkpoint_path, log_dir) def _prepareCheckpoint(self, checkpoint_path): diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 733c7fb95dd..2a0e4e76175 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -38,6 +38,7 @@ from tensorflow.python.estimator.export import export_output from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util @@ -1296,6 +1297,31 @@ class EstimatorEvaluateTest(test.TestCase): dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint()) self.assertEqual(5, scores['global_step']) + def test_wrong_shape_throws_reasonable_error(self): + """Make sure we are helpful when model_fns change. See b/110263146.""" + def _get_model_fn(val=1): + def _model_fn(features, labels, mode): + del features, labels # unused + variables.Variable(val, name='weight') + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions=constant_op.constant([[1.]]), + loss=constant_op.constant(0.), + train_op=state_ops.assign_add(training.get_global_step(), 1)) + return _model_fn + + model_fn_1 = _get_model_fn() + model_fn_2 = _get_model_fn(val=[1]) + + est1 = estimator.Estimator(model_fn=model_fn_1) + est1.train(dummy_input_fn, steps=5) + est2 = estimator.Estimator( + model_fn=model_fn_2, model_dir=est1.model_dir) + + expected_msg = 'Restoring from checkpoint failed.*a mismatch between' + with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg): + est2.train(dummy_input_fn, steps=1,) + def test_scaffold_is_used(self): def _model_fn_scaffold(features, labels, mode): diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 53ed89e4ab8..1ee975fbe48 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -22,7 +22,6 @@ from __future__ import print_function import collections import os.path import re -import sys import time import uuid @@ -1043,8 +1042,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None): ckpt = CheckpointState() text_format.Merge(file_content, ckpt) if not ckpt.model_checkpoint_path: - raise ValueError("Invalid checkpoint state loaded from %s", - checkpoint_dir) + raise ValueError("Invalid checkpoint state loaded from " + + checkpoint_dir) # For relative model_checkpoint_path and all_model_checkpoint_paths, # prepend checkpoint_dir. if not os.path.isabs(ckpt.model_checkpoint_path): @@ -1706,12 +1705,17 @@ class Saver(object): save_path: Path where parameters were previously saved. Raises: - ValueError: If save_path is None. + ValueError: If save_path is None or not a valid checkpoint. """ if self._is_empty: return if save_path is None: raise ValueError("Can't load save_path when it is None.") + + if not checkpoint_exists(compat.as_text(save_path)): + raise ValueError("The passed save_path is not a valid checkpoint: " + + compat.as_text(save_path)) + logging.info("Restoring parameters from %s", compat.as_text(save_path)) try: if context.executing_eagerly(): @@ -1719,23 +1723,24 @@ class Saver(object): else: sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path}) - except errors.NotFoundError: - exception_type, exception_value, exception_traceback = sys.exc_info() - # The checkpoint would not be loaded successfully as is. Try to parse it - # as an object-based checkpoint. - should_reraise = False + except errors.NotFoundError as err: + # There are three common conditions that might cause this error: + # 0. The file is missing. We ignore here, as this is checked above. + # 1. This is an object-based checkpoint trying name-based loading. + # 2. The graph has been altered and a variable or other name is missing. + + # 1. The checkpoint would not be loaded successfully as is. Try to parse + # it as an object-based checkpoint. try: reader = pywrap_tensorflow.NewCheckpointReader(save_path) object_graph_string = reader.get_tensor( checkpointable.OBJECT_GRAPH_PROTO_KEY) except errors.NotFoundError: - # This is not an object-based checkpoint, or the checkpoint doesn't - # exist. Re-raise the original exception, but do it outside the except - # block so the object graph lookup isn't included in the stack trace. - should_reraise = True - if should_reraise: - six.reraise(exception_type, exception_value, exception_traceback) - del exception_traceback # avoid reference cycles + # 2. This is not an object-based checkpoint, which likely means there + # is a graph mismatch. Re-raise the original error with + # a helpful message (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a Variable name or other graph key that is missing") # This is an object-based checkpoint. We'll print a warning and then do # the restore. @@ -1747,6 +1752,11 @@ class Saver(object): self._restore_from_object_based_checkpoint( sess=sess, save_path=save_path, object_graph_string=object_graph_string) + except errors.InvalidArgumentError as err: + # There is a mismatch between the graph and the checkpoint being loaded. + # We add a more reasonable error message here to help users (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a mismatch between the current graph and the graph") def _restore_from_object_based_checkpoint(self, sess, save_path, object_graph_string): @@ -2139,6 +2149,14 @@ def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): return meta_graph_filename +def _wrap_restore_error_with_msg(err, extra_verbiage): + err_msg = ("Restoring from checkpoint failed. This is most likely " + "due to {} from the checkpoint. Please ensure that you " + "have not altered the graph expected based on the checkpoint. " + "Original error:\n\n{}").format(extra_verbiage, err.message) + return err.__class__(err.node_def, err.op, err_msg) + + ops.register_proto_function( ops.GraphKeys.SAVERS, proto_type=saver_pb2.SaverDef, diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index f235300eb5c..ae9c244aaf3 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -24,10 +24,8 @@ import math import os import random import shutil -import sys import tempfile import time -import traceback import numpy as np import six @@ -369,8 +367,8 @@ class SaverTest(test.TestCase): for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2): with self.test_session() as sess: save = saver_module.Saver({"v0": v0}, write_version=ver) - with self.assertRaisesRegexp(errors.NotFoundError, - "Failed to find any matching files for"): + with self.assertRaisesRegexp( + ValueError, "The passed save_path is not a valid checkpoint:"): save.restore(sess, "invalid path") def testInt64(self): @@ -3139,27 +3137,33 @@ class CheckpointableCompatibilityTests(test.TestCase): errors.NotFoundError, "Key b not found in checkpoint"): b_saver.restore(sess=sess, save_path=save_path) - def testCheckpointNotFoundErrorRaised(self): - # Restore does some tricky exception handling to figure out if it should - # load an object-based checkpoint. Tests that the exception handling isn't - # too broad. - a = resource_variable_ops.ResourceVariable(1., name="a") - saver = saver_module.Saver([a]) - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.NotFoundError, - "Failed to find any matching files for path_which_does_not_exist"): - saver.restore(sess=sess, save_path="path_which_does_not_exist") - try: - saver.restore(sess=sess, save_path="path_which_does_not_exist") - except errors.NotFoundError: - # Make sure we don't have a confusing "During handling of the above - # exception" block in Python 3. - # pylint: disable=no-value-for-parameter - exception_string = "\n".join( - traceback.format_exception(*sys.exc_info())) - # pylint: enable=no-value-for-parameter - self.assertNotIn("NewCheckpointReader", exception_string) + with self.assertRaises(errors.NotFoundError) as cs: + b_saver.restore(sess=sess, save_path=save_path) + + # Make sure we don't have a confusing "During handling of the above + # exception" block in Python 3. + self.assertNotIn("NewCheckpointReader", cs.exception.message) + + def testGraphChangedForRestoreErrorRaised(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + with ops_lib.Graph().as_default() as g: + a = variables.Variable(1., name="a") + a_saver = saver_module.Saver([a]) + + with self.test_session(graph=g) as sess: + sess.run(a.initializer) + save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix) + + with ops_lib.Graph().as_default() as g: + a = variables.Variable([1.], name="a") + a_saver = saver_module.Saver([a]) + with self.test_session(graph=g) as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "a mismatch between the current graph and the graph"): + a_saver.restore(sess=sess, save_path=save_path) def testLoadFromObjectBasedGraph(self): checkpoint_directory = self.get_temp_dir()