Add more helpful error messages when restoring from checkpoint fails.

PiperOrigin-RevId: 202668227
This commit is contained in:
Karmel Allison 2018-06-29 10:30:13 -07:00 committed by TensorFlower Gardener
parent f139bc3f5c
commit 0979821324
4 changed files with 90 additions and 43 deletions

View File

@ -33,7 +33,6 @@ from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import hooks from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics from tensorflow.python.ops import metrics
@ -242,7 +241,7 @@ class SingleEvaluationTest(test.TestCase):
checkpoint_path = os.path.join(self.get_temp_dir(), checkpoint_path = os.path.join(self.get_temp_dir(),
'this_file_doesnt_exist') 'this_file_doesnt_exist')
log_dir = os.path.join(self.get_temp_dir(), 'error_raised') log_dir = os.path.join(self.get_temp_dir(), 'error_raised')
with self.assertRaises(errors.NotFoundError): with self.assertRaises(ValueError):
evaluation.evaluate_once('', checkpoint_path, log_dir) evaluation.evaluate_once('', checkpoint_path, log_dir)
def _prepareCheckpoint(self, checkpoint_path): def _prepareCheckpoint(self, checkpoint_path):

View File

@ -38,6 +38,7 @@ from tensorflow.python.estimator.export import export_output
from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -1296,6 +1297,31 @@ class EstimatorEvaluateTest(test.TestCase):
dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint()) dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint())
self.assertEqual(5, scores['global_step']) self.assertEqual(5, scores['global_step'])
def test_wrong_shape_throws_reasonable_error(self):
  """Checks that restoring into a changed model_fn gives a helpful error.

  One estimator is trained with a model_fn that creates `weight` from `1`.
  A second estimator reuses the same model_dir, but its model_fn creates
  `weight` from `[1]`, and training it is expected to raise an
  InvalidArgumentError carrying the restore guidance. See b/110263146.
  """

  def _make_model_fn(initial_value):
    """Returns a model_fn whose `weight` variable is built from `initial_value`."""

    def _model_fn(features, labels, mode):
      del features, labels  # unused
      variables.Variable(initial_value, name='weight')
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=constant_op.constant([[1.]]),
          loss=constant_op.constant(0.),
          train_op=state_ops.assign_add(training.get_global_step(), 1))

    return _model_fn

  # Write a checkpoint with the first variant of the model.
  trained = estimator.Estimator(model_fn=_make_model_fn(1))
  trained.train(dummy_input_fn, steps=5)

  # Point a second estimator, built from the other variant, at that same
  # model_dir so that it tries to restore the checkpoint written above.
  mismatched = estimator.Estimator(
      model_fn=_make_model_fn([1]), model_dir=trained.model_dir)
  with self.assertRaisesRegexp(
      errors.InvalidArgumentError,
      'Restoring from checkpoint failed.*a mismatch between'):
    mismatched.train(dummy_input_fn, steps=1)
def test_scaffold_is_used(self): def test_scaffold_is_used(self):
def _model_fn_scaffold(features, labels, mode): def _model_fn_scaffold(features, labels, mode):

View File

@ -22,7 +22,6 @@ from __future__ import print_function
import collections import collections
import os.path import os.path
import re import re
import sys
import time import time
import uuid import uuid
@ -1043,8 +1042,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
ckpt = CheckpointState() ckpt = CheckpointState()
text_format.Merge(file_content, ckpt) text_format.Merge(file_content, ckpt)
if not ckpt.model_checkpoint_path: if not ckpt.model_checkpoint_path:
raise ValueError("Invalid checkpoint state loaded from %s", raise ValueError("Invalid checkpoint state loaded from "
checkpoint_dir) + checkpoint_dir)
# For relative model_checkpoint_path and all_model_checkpoint_paths, # For relative model_checkpoint_path and all_model_checkpoint_paths,
# prepend checkpoint_dir. # prepend checkpoint_dir.
if not os.path.isabs(ckpt.model_checkpoint_path): if not os.path.isabs(ckpt.model_checkpoint_path):
@ -1706,12 +1705,17 @@ class Saver(object):
save_path: Path where parameters were previously saved. save_path: Path where parameters were previously saved.
Raises: Raises:
ValueError: If save_path is None. ValueError: If save_path is None or not a valid checkpoint.
""" """
if self._is_empty: if self._is_empty:
return return
if save_path is None: if save_path is None:
raise ValueError("Can't load save_path when it is None.") raise ValueError("Can't load save_path when it is None.")
if not checkpoint_exists(compat.as_text(save_path)):
raise ValueError("The passed save_path is not a valid checkpoint: "
+ compat.as_text(save_path))
logging.info("Restoring parameters from %s", compat.as_text(save_path)) logging.info("Restoring parameters from %s", compat.as_text(save_path))
try: try:
if context.executing_eagerly(): if context.executing_eagerly():
@ -1719,23 +1723,24 @@ class Saver(object):
else: else:
sess.run(self.saver_def.restore_op_name, sess.run(self.saver_def.restore_op_name,
{self.saver_def.filename_tensor_name: save_path}) {self.saver_def.filename_tensor_name: save_path})
except errors.NotFoundError: except errors.NotFoundError as err:
exception_type, exception_value, exception_traceback = sys.exc_info() # There are three common conditions that might cause this error:
# The checkpoint would not be loaded successfully as is. Try to parse it # 0. The file is missing. We ignore here, as this is checked above.
# as an object-based checkpoint. # 1. This is an object-based checkpoint trying name-based loading.
should_reraise = False # 2. The graph has been altered and a variable or other name is missing.
# 1. The checkpoint would not be loaded successfully as is. Try to parse
# it as an object-based checkpoint.
try: try:
reader = pywrap_tensorflow.NewCheckpointReader(save_path) reader = pywrap_tensorflow.NewCheckpointReader(save_path)
object_graph_string = reader.get_tensor( object_graph_string = reader.get_tensor(
checkpointable.OBJECT_GRAPH_PROTO_KEY) checkpointable.OBJECT_GRAPH_PROTO_KEY)
except errors.NotFoundError: except errors.NotFoundError:
# This is not an object-based checkpoint, or the checkpoint doesn't # 2. This is not an object-based checkpoint, which likely means there
# exist. Re-raise the original exception, but do it outside the except # is a graph mismatch. Re-raise the original error with
# block so the object graph lookup isn't included in the stack trace. # a helpful message (b/110263146)
should_reraise = True raise _wrap_restore_error_with_msg(
if should_reraise: err, "a Variable name or other graph key that is missing")
six.reraise(exception_type, exception_value, exception_traceback)
del exception_traceback # avoid reference cycles
# This is an object-based checkpoint. We'll print a warning and then do # This is an object-based checkpoint. We'll print a warning and then do
# the restore. # the restore.
@ -1747,6 +1752,11 @@ class Saver(object):
self._restore_from_object_based_checkpoint( self._restore_from_object_based_checkpoint(
sess=sess, save_path=save_path, sess=sess, save_path=save_path,
object_graph_string=object_graph_string) object_graph_string=object_graph_string)
except errors.InvalidArgumentError as err:
# There is a mismatch between the graph and the checkpoint being loaded.
# We add a more reasonable error message here to help users (b/110263146)
raise _wrap_restore_error_with_msg(
err, "a mismatch between the current graph and the graph")
def _restore_from_object_based_checkpoint(self, sess, save_path, def _restore_from_object_based_checkpoint(self, sess, save_path,
object_graph_string): object_graph_string):
@ -2139,6 +2149,14 @@ def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
return meta_graph_filename return meta_graph_filename
def _wrap_restore_error_with_msg(err, extra_verbiage):
  """Builds a same-typed error whose message adds restore guidance.

  Args:
    err: The caught error. Its `node_def`, `op` and `message` attributes are
      read, and its class is called as `cls(node_def, op, message)`.
    extra_verbiage: Phrase describing the likely cause; it is inserted into
      the guidance sentence ahead of "from the checkpoint".

  Returns:
    A new instance of `err.__class__` with the original `node_def` and `op`
    and the expanded message. The error is returned, not raised.
  """
  template = ("Restoring from checkpoint failed. This is most likely "
              "due to {} from the checkpoint. Please ensure that you "
              "have not altered the graph expected based on the checkpoint. "
              "Original error:\n\n{}")
  wrapped_message = template.format(extra_verbiage, err.message)
  error_cls = err.__class__
  return error_cls(err.node_def, err.op, wrapped_message)
ops.register_proto_function( ops.register_proto_function(
ops.GraphKeys.SAVERS, ops.GraphKeys.SAVERS,
proto_type=saver_pb2.SaverDef, proto_type=saver_pb2.SaverDef,

View File

@ -24,10 +24,8 @@ import math
import os import os
import random import random
import shutil import shutil
import sys
import tempfile import tempfile
import time import time
import traceback
import numpy as np import numpy as np
import six import six
@ -369,8 +367,8 @@ class SaverTest(test.TestCase):
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2): for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
with self.test_session() as sess: with self.test_session() as sess:
save = saver_module.Saver({"v0": v0}, write_version=ver) save = saver_module.Saver({"v0": v0}, write_version=ver)
with self.assertRaisesRegexp(errors.NotFoundError, with self.assertRaisesRegexp(
"Failed to find any matching files for"): ValueError, "The passed save_path is not a valid checkpoint:"):
save.restore(sess, "invalid path") save.restore(sess, "invalid path")
def testInt64(self): def testInt64(self):
@ -3139,27 +3137,33 @@ class CheckpointableCompatibilityTests(test.TestCase):
errors.NotFoundError, "Key b not found in checkpoint"): errors.NotFoundError, "Key b not found in checkpoint"):
b_saver.restore(sess=sess, save_path=save_path) b_saver.restore(sess=sess, save_path=save_path)
def testCheckpointNotFoundErrorRaised(self): with self.assertRaises(errors.NotFoundError) as cs:
# Restore does some tricky exception handling to figure out if it should b_saver.restore(sess=sess, save_path=save_path)
# load an object-based checkpoint. Tests that the exception handling isn't
# too broad. # Make sure we don't have a confusing "During handling of the above
a = resource_variable_ops.ResourceVariable(1., name="a") # exception" block in Python 3.
saver = saver_module.Saver([a]) self.assertNotIn("NewCheckpointReader", cs.exception.message)
with self.test_session() as sess:
with self.assertRaisesRegexp( def testGraphChangedForRestoreErrorRaised(self):
errors.NotFoundError, checkpoint_directory = self.get_temp_dir()
"Failed to find any matching files for path_which_does_not_exist"): checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
saver.restore(sess=sess, save_path="path_which_does_not_exist")
try: with ops_lib.Graph().as_default() as g:
saver.restore(sess=sess, save_path="path_which_does_not_exist") a = variables.Variable(1., name="a")
except errors.NotFoundError: a_saver = saver_module.Saver([a])
# Make sure we don't have a confusing "During handling of the above
# exception" block in Python 3. with self.test_session(graph=g) as sess:
# pylint: disable=no-value-for-parameter sess.run(a.initializer)
exception_string = "\n".join( save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
traceback.format_exception(*sys.exc_info()))
# pylint: enable=no-value-for-parameter with ops_lib.Graph().as_default() as g:
self.assertNotIn("NewCheckpointReader", exception_string) a = variables.Variable([1.], name="a")
a_saver = saver_module.Saver([a])
with self.test_session(graph=g) as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"a mismatch between the current graph and the graph"):
a_saver.restore(sess=sess, save_path=save_path)
def testLoadFromObjectBasedGraph(self): def testLoadFromObjectBasedGraph(self):
checkpoint_directory = self.get_temp_dir() checkpoint_directory = self.get_temp_dir()