Add more helpful error messages when restoring from checkpoint fails.
PiperOrigin-RevId: 202668227
This commit is contained in:
parent
f139bc3f5c
commit
0979821324
@ -33,7 +33,6 @@ from tensorflow.python.debug.lib import debug_data
|
||||
from tensorflow.python.debug.wrappers import hooks
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import metrics
|
||||
@ -242,7 +241,7 @@ class SingleEvaluationTest(test.TestCase):
|
||||
checkpoint_path = os.path.join(self.get_temp_dir(),
|
||||
'this_file_doesnt_exist')
|
||||
log_dir = os.path.join(self.get_temp_dir(), 'error_raised')
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
with self.assertRaises(ValueError):
|
||||
evaluation.evaluate_once('', checkpoint_path, log_dir)
|
||||
|
||||
def _prepareCheckpoint(self, checkpoint_path):
|
||||
|
@ -38,6 +38,7 @@ from tensorflow.python.estimator.export import export_output
|
||||
from tensorflow.python.estimator.inputs import numpy_io
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -1296,6 +1297,31 @@ class EstimatorEvaluateTest(test.TestCase):
|
||||
dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint())
|
||||
self.assertEqual(5, scores['global_step'])
|
||||
|
||||
def test_wrong_shape_throws_reasonable_error(self):
|
||||
"""Make sure we are helpful when model_fns change. See b/110263146."""
|
||||
def _get_model_fn(val=1):
|
||||
def _model_fn(features, labels, mode):
|
||||
del features, labels # unused
|
||||
variables.Variable(val, name='weight')
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=mode,
|
||||
predictions=constant_op.constant([[1.]]),
|
||||
loss=constant_op.constant(0.),
|
||||
train_op=state_ops.assign_add(training.get_global_step(), 1))
|
||||
return _model_fn
|
||||
|
||||
model_fn_1 = _get_model_fn()
|
||||
model_fn_2 = _get_model_fn(val=[1])
|
||||
|
||||
est1 = estimator.Estimator(model_fn=model_fn_1)
|
||||
est1.train(dummy_input_fn, steps=5)
|
||||
est2 = estimator.Estimator(
|
||||
model_fn=model_fn_2, model_dir=est1.model_dir)
|
||||
|
||||
expected_msg = 'Restoring from checkpoint failed.*a mismatch between'
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg):
|
||||
est2.train(dummy_input_fn, steps=1,)
|
||||
|
||||
def test_scaffold_is_used(self):
|
||||
|
||||
def _model_fn_scaffold(features, labels, mode):
|
||||
|
@ -22,7 +22,6 @@ from __future__ import print_function
|
||||
import collections
|
||||
import os.path
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
|
||||
@ -1043,8 +1042,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
|
||||
ckpt = CheckpointState()
|
||||
text_format.Merge(file_content, ckpt)
|
||||
if not ckpt.model_checkpoint_path:
|
||||
raise ValueError("Invalid checkpoint state loaded from %s",
|
||||
checkpoint_dir)
|
||||
raise ValueError("Invalid checkpoint state loaded from "
|
||||
+ checkpoint_dir)
|
||||
# For relative model_checkpoint_path and all_model_checkpoint_paths,
|
||||
# prepend checkpoint_dir.
|
||||
if not os.path.isabs(ckpt.model_checkpoint_path):
|
||||
@ -1706,12 +1705,17 @@ class Saver(object):
|
||||
save_path: Path where parameters were previously saved.
|
||||
|
||||
Raises:
|
||||
ValueError: If save_path is None.
|
||||
ValueError: If save_path is None or not a valid checkpoint.
|
||||
"""
|
||||
if self._is_empty:
|
||||
return
|
||||
if save_path is None:
|
||||
raise ValueError("Can't load save_path when it is None.")
|
||||
|
||||
if not checkpoint_exists(compat.as_text(save_path)):
|
||||
raise ValueError("The passed save_path is not a valid checkpoint: "
|
||||
+ compat.as_text(save_path))
|
||||
|
||||
logging.info("Restoring parameters from %s", compat.as_text(save_path))
|
||||
try:
|
||||
if context.executing_eagerly():
|
||||
@ -1719,23 +1723,24 @@ class Saver(object):
|
||||
else:
|
||||
sess.run(self.saver_def.restore_op_name,
|
||||
{self.saver_def.filename_tensor_name: save_path})
|
||||
except errors.NotFoundError:
|
||||
exception_type, exception_value, exception_traceback = sys.exc_info()
|
||||
# The checkpoint would not be loaded successfully as is. Try to parse it
|
||||
# as an object-based checkpoint.
|
||||
should_reraise = False
|
||||
except errors.NotFoundError as err:
|
||||
# There are three common conditions that might cause this error:
|
||||
# 0. The file is missing. We ignore here, as this is checked above.
|
||||
# 1. This is an object-based checkpoint trying name-based loading.
|
||||
# 2. The graph has been altered and a variable or other name is missing.
|
||||
|
||||
# 1. The checkpoint would not be loaded successfully as is. Try to parse
|
||||
# it as an object-based checkpoint.
|
||||
try:
|
||||
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
|
||||
object_graph_string = reader.get_tensor(
|
||||
checkpointable.OBJECT_GRAPH_PROTO_KEY)
|
||||
except errors.NotFoundError:
|
||||
# This is not an object-based checkpoint, or the checkpoint doesn't
|
||||
# exist. Re-raise the original exception, but do it outside the except
|
||||
# block so the object graph lookup isn't included in the stack trace.
|
||||
should_reraise = True
|
||||
if should_reraise:
|
||||
six.reraise(exception_type, exception_value, exception_traceback)
|
||||
del exception_traceback # avoid reference cycles
|
||||
# 2. This is not an object-based checkpoint, which likely means there
|
||||
# is a graph mismatch. Re-raise the original error with
|
||||
# a helpful message (b/110263146)
|
||||
raise _wrap_restore_error_with_msg(
|
||||
err, "a Variable name or other graph key that is missing")
|
||||
|
||||
# This is an object-based checkpoint. We'll print a warning and then do
|
||||
# the restore.
|
||||
@ -1747,6 +1752,11 @@ class Saver(object):
|
||||
self._restore_from_object_based_checkpoint(
|
||||
sess=sess, save_path=save_path,
|
||||
object_graph_string=object_graph_string)
|
||||
except errors.InvalidArgumentError as err:
|
||||
# There is a mismatch between the graph and the checkpoint being loaded.
|
||||
# We add a more reasonable error message here to help users (b/110263146)
|
||||
raise _wrap_restore_error_with_msg(
|
||||
err, "a mismatch between the current graph and the graph")
|
||||
|
||||
def _restore_from_object_based_checkpoint(self, sess, save_path,
|
||||
object_graph_string):
|
||||
@ -2139,6 +2149,14 @@ def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
|
||||
return meta_graph_filename
|
||||
|
||||
|
||||
def _wrap_restore_error_with_msg(err, extra_verbiage):
|
||||
err_msg = ("Restoring from checkpoint failed. This is most likely "
|
||||
"due to {} from the checkpoint. Please ensure that you "
|
||||
"have not altered the graph expected based on the checkpoint. "
|
||||
"Original error:\n\n{}").format(extra_verbiage, err.message)
|
||||
return err.__class__(err.node_def, err.op, err_msg)
|
||||
|
||||
|
||||
ops.register_proto_function(
|
||||
ops.GraphKeys.SAVERS,
|
||||
proto_type=saver_pb2.SaverDef,
|
||||
|
@ -24,10 +24,8 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
@ -369,8 +367,8 @@ class SaverTest(test.TestCase):
|
||||
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
|
||||
with self.test_session() as sess:
|
||||
save = saver_module.Saver({"v0": v0}, write_version=ver)
|
||||
with self.assertRaisesRegexp(errors.NotFoundError,
|
||||
"Failed to find any matching files for"):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "The passed save_path is not a valid checkpoint:"):
|
||||
save.restore(sess, "invalid path")
|
||||
|
||||
def testInt64(self):
|
||||
@ -3139,27 +3137,33 @@ class CheckpointableCompatibilityTests(test.TestCase):
|
||||
errors.NotFoundError, "Key b not found in checkpoint"):
|
||||
b_saver.restore(sess=sess, save_path=save_path)
|
||||
|
||||
def testCheckpointNotFoundErrorRaised(self):
|
||||
# Restore does some tricky exception handling to figure out if it should
|
||||
# load an object-based checkpoint. Tests that the exception handling isn't
|
||||
# too broad.
|
||||
a = resource_variable_ops.ResourceVariable(1., name="a")
|
||||
saver = saver_module.Saver([a])
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
errors.NotFoundError,
|
||||
"Failed to find any matching files for path_which_does_not_exist"):
|
||||
saver.restore(sess=sess, save_path="path_which_does_not_exist")
|
||||
try:
|
||||
saver.restore(sess=sess, save_path="path_which_does_not_exist")
|
||||
except errors.NotFoundError:
|
||||
# Make sure we don't have a confusing "During handling of the above
|
||||
# exception" block in Python 3.
|
||||
# pylint: disable=no-value-for-parameter
|
||||
exception_string = "\n".join(
|
||||
traceback.format_exception(*sys.exc_info()))
|
||||
# pylint: enable=no-value-for-parameter
|
||||
self.assertNotIn("NewCheckpointReader", exception_string)
|
||||
with self.assertRaises(errors.NotFoundError) as cs:
|
||||
b_saver.restore(sess=sess, save_path=save_path)
|
||||
|
||||
# Make sure we don't have a confusing "During handling of the above
|
||||
# exception" block in Python 3.
|
||||
self.assertNotIn("NewCheckpointReader", cs.exception.message)
|
||||
|
||||
def testGraphChangedForRestoreErrorRaised(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
|
||||
with ops_lib.Graph().as_default() as g:
|
||||
a = variables.Variable(1., name="a")
|
||||
a_saver = saver_module.Saver([a])
|
||||
|
||||
with self.test_session(graph=g) as sess:
|
||||
sess.run(a.initializer)
|
||||
save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
|
||||
|
||||
with ops_lib.Graph().as_default() as g:
|
||||
a = variables.Variable([1.], name="a")
|
||||
a_saver = saver_module.Saver([a])
|
||||
with self.test_session(graph=g) as sess:
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
"a mismatch between the current graph and the graph"):
|
||||
a_saver.restore(sess=sess, save_path=save_path)
|
||||
|
||||
def testLoadFromObjectBasedGraph(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
|
Loading…
Reference in New Issue
Block a user