Add more helpful error messages when restoring from checkpoint fails.
PiperOrigin-RevId: 202668227
This commit is contained in:
parent
f139bc3f5c
commit
0979821324
@ -33,7 +33,6 @@ from tensorflow.python.debug.lib import debug_data
|
|||||||
from tensorflow.python.debug.wrappers import hooks
|
from tensorflow.python.debug.wrappers import hooks
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import metrics
|
from tensorflow.python.ops import metrics
|
||||||
@ -242,7 +241,7 @@ class SingleEvaluationTest(test.TestCase):
|
|||||||
checkpoint_path = os.path.join(self.get_temp_dir(),
|
checkpoint_path = os.path.join(self.get_temp_dir(),
|
||||||
'this_file_doesnt_exist')
|
'this_file_doesnt_exist')
|
||||||
log_dir = os.path.join(self.get_temp_dir(), 'error_raised')
|
log_dir = os.path.join(self.get_temp_dir(), 'error_raised')
|
||||||
with self.assertRaises(errors.NotFoundError):
|
with self.assertRaises(ValueError):
|
||||||
evaluation.evaluate_once('', checkpoint_path, log_dir)
|
evaluation.evaluate_once('', checkpoint_path, log_dir)
|
||||||
|
|
||||||
def _prepareCheckpoint(self, checkpoint_path):
|
def _prepareCheckpoint(self, checkpoint_path):
|
||||||
|
@ -38,6 +38,7 @@ from tensorflow.python.estimator.export import export_output
|
|||||||
from tensorflow.python.estimator.inputs import numpy_io
|
from tensorflow.python.estimator.inputs import numpy_io
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -1296,6 +1297,31 @@ class EstimatorEvaluateTest(test.TestCase):
|
|||||||
dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint())
|
dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint())
|
||||||
self.assertEqual(5, scores['global_step'])
|
self.assertEqual(5, scores['global_step'])
|
||||||
|
|
||||||
|
def test_wrong_shape_throws_reasonable_error(self):
|
||||||
|
"""Make sure we are helpful when model_fns change. See b/110263146."""
|
||||||
|
def _get_model_fn(val=1):
|
||||||
|
def _model_fn(features, labels, mode):
|
||||||
|
del features, labels # unused
|
||||||
|
variables.Variable(val, name='weight')
|
||||||
|
return model_fn_lib.EstimatorSpec(
|
||||||
|
mode=mode,
|
||||||
|
predictions=constant_op.constant([[1.]]),
|
||||||
|
loss=constant_op.constant(0.),
|
||||||
|
train_op=state_ops.assign_add(training.get_global_step(), 1))
|
||||||
|
return _model_fn
|
||||||
|
|
||||||
|
model_fn_1 = _get_model_fn()
|
||||||
|
model_fn_2 = _get_model_fn(val=[1])
|
||||||
|
|
||||||
|
est1 = estimator.Estimator(model_fn=model_fn_1)
|
||||||
|
est1.train(dummy_input_fn, steps=5)
|
||||||
|
est2 = estimator.Estimator(
|
||||||
|
model_fn=model_fn_2, model_dir=est1.model_dir)
|
||||||
|
|
||||||
|
expected_msg = 'Restoring from checkpoint failed.*a mismatch between'
|
||||||
|
with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg):
|
||||||
|
est2.train(dummy_input_fn, steps=1,)
|
||||||
|
|
||||||
def test_scaffold_is_used(self):
|
def test_scaffold_is_used(self):
|
||||||
|
|
||||||
def _model_fn_scaffold(features, labels, mode):
|
def _model_fn_scaffold(features, labels, mode):
|
||||||
|
@ -22,7 +22,6 @@ from __future__ import print_function
|
|||||||
import collections
|
import collections
|
||||||
import os.path
|
import os.path
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
@ -1043,8 +1042,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
|
|||||||
ckpt = CheckpointState()
|
ckpt = CheckpointState()
|
||||||
text_format.Merge(file_content, ckpt)
|
text_format.Merge(file_content, ckpt)
|
||||||
if not ckpt.model_checkpoint_path:
|
if not ckpt.model_checkpoint_path:
|
||||||
raise ValueError("Invalid checkpoint state loaded from %s",
|
raise ValueError("Invalid checkpoint state loaded from "
|
||||||
checkpoint_dir)
|
+ checkpoint_dir)
|
||||||
# For relative model_checkpoint_path and all_model_checkpoint_paths,
|
# For relative model_checkpoint_path and all_model_checkpoint_paths,
|
||||||
# prepend checkpoint_dir.
|
# prepend checkpoint_dir.
|
||||||
if not os.path.isabs(ckpt.model_checkpoint_path):
|
if not os.path.isabs(ckpt.model_checkpoint_path):
|
||||||
@ -1706,12 +1705,17 @@ class Saver(object):
|
|||||||
save_path: Path where parameters were previously saved.
|
save_path: Path where parameters were previously saved.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If save_path is None.
|
ValueError: If save_path is None or not a valid checkpoint.
|
||||||
"""
|
"""
|
||||||
if self._is_empty:
|
if self._is_empty:
|
||||||
return
|
return
|
||||||
if save_path is None:
|
if save_path is None:
|
||||||
raise ValueError("Can't load save_path when it is None.")
|
raise ValueError("Can't load save_path when it is None.")
|
||||||
|
|
||||||
|
if not checkpoint_exists(compat.as_text(save_path)):
|
||||||
|
raise ValueError("The passed save_path is not a valid checkpoint: "
|
||||||
|
+ compat.as_text(save_path))
|
||||||
|
|
||||||
logging.info("Restoring parameters from %s", compat.as_text(save_path))
|
logging.info("Restoring parameters from %s", compat.as_text(save_path))
|
||||||
try:
|
try:
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
@ -1719,23 +1723,24 @@ class Saver(object):
|
|||||||
else:
|
else:
|
||||||
sess.run(self.saver_def.restore_op_name,
|
sess.run(self.saver_def.restore_op_name,
|
||||||
{self.saver_def.filename_tensor_name: save_path})
|
{self.saver_def.filename_tensor_name: save_path})
|
||||||
except errors.NotFoundError:
|
except errors.NotFoundError as err:
|
||||||
exception_type, exception_value, exception_traceback = sys.exc_info()
|
# There are three common conditions that might cause this error:
|
||||||
# The checkpoint would not be loaded successfully as is. Try to parse it
|
# 0. The file is missing. We ignore here, as this is checked above.
|
||||||
# as an object-based checkpoint.
|
# 1. This is an object-based checkpoint trying name-based loading.
|
||||||
should_reraise = False
|
# 2. The graph has been altered and a variable or other name is missing.
|
||||||
|
|
||||||
|
# 1. The checkpoint would not be loaded successfully as is. Try to parse
|
||||||
|
# it as an object-based checkpoint.
|
||||||
try:
|
try:
|
||||||
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
|
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
|
||||||
object_graph_string = reader.get_tensor(
|
object_graph_string = reader.get_tensor(
|
||||||
checkpointable.OBJECT_GRAPH_PROTO_KEY)
|
checkpointable.OBJECT_GRAPH_PROTO_KEY)
|
||||||
except errors.NotFoundError:
|
except errors.NotFoundError:
|
||||||
# This is not an object-based checkpoint, or the checkpoint doesn't
|
# 2. This is not an object-based checkpoint, which likely means there
|
||||||
# exist. Re-raise the original exception, but do it outside the except
|
# is a graph mismatch. Re-raise the original error with
|
||||||
# block so the object graph lookup isn't included in the stack trace.
|
# a helpful message (b/110263146)
|
||||||
should_reraise = True
|
raise _wrap_restore_error_with_msg(
|
||||||
if should_reraise:
|
err, "a Variable name or other graph key that is missing")
|
||||||
six.reraise(exception_type, exception_value, exception_traceback)
|
|
||||||
del exception_traceback # avoid reference cycles
|
|
||||||
|
|
||||||
# This is an object-based checkpoint. We'll print a warning and then do
|
# This is an object-based checkpoint. We'll print a warning and then do
|
||||||
# the restore.
|
# the restore.
|
||||||
@ -1747,6 +1752,11 @@ class Saver(object):
|
|||||||
self._restore_from_object_based_checkpoint(
|
self._restore_from_object_based_checkpoint(
|
||||||
sess=sess, save_path=save_path,
|
sess=sess, save_path=save_path,
|
||||||
object_graph_string=object_graph_string)
|
object_graph_string=object_graph_string)
|
||||||
|
except errors.InvalidArgumentError as err:
|
||||||
|
# There is a mismatch between the graph and the checkpoint being loaded.
|
||||||
|
# We add a more reasonable error message here to help users (b/110263146)
|
||||||
|
raise _wrap_restore_error_with_msg(
|
||||||
|
err, "a mismatch between the current graph and the graph")
|
||||||
|
|
||||||
def _restore_from_object_based_checkpoint(self, sess, save_path,
|
def _restore_from_object_based_checkpoint(self, sess, save_path,
|
||||||
object_graph_string):
|
object_graph_string):
|
||||||
@ -2139,6 +2149,14 @@ def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
|
|||||||
return meta_graph_filename
|
return meta_graph_filename
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_restore_error_with_msg(err, extra_verbiage):
|
||||||
|
err_msg = ("Restoring from checkpoint failed. This is most likely "
|
||||||
|
"due to {} from the checkpoint. Please ensure that you "
|
||||||
|
"have not altered the graph expected based on the checkpoint. "
|
||||||
|
"Original error:\n\n{}").format(extra_verbiage, err.message)
|
||||||
|
return err.__class__(err.node_def, err.op, err_msg)
|
||||||
|
|
||||||
|
|
||||||
ops.register_proto_function(
|
ops.register_proto_function(
|
||||||
ops.GraphKeys.SAVERS,
|
ops.GraphKeys.SAVERS,
|
||||||
proto_type=saver_pb2.SaverDef,
|
proto_type=saver_pb2.SaverDef,
|
||||||
|
@ -24,10 +24,8 @@ import math
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import traceback
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
@ -369,8 +367,8 @@ class SaverTest(test.TestCase):
|
|||||||
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
|
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
save = saver_module.Saver({"v0": v0}, write_version=ver)
|
save = saver_module.Saver({"v0": v0}, write_version=ver)
|
||||||
with self.assertRaisesRegexp(errors.NotFoundError,
|
with self.assertRaisesRegexp(
|
||||||
"Failed to find any matching files for"):
|
ValueError, "The passed save_path is not a valid checkpoint:"):
|
||||||
save.restore(sess, "invalid path")
|
save.restore(sess, "invalid path")
|
||||||
|
|
||||||
def testInt64(self):
|
def testInt64(self):
|
||||||
@ -3139,27 +3137,33 @@ class CheckpointableCompatibilityTests(test.TestCase):
|
|||||||
errors.NotFoundError, "Key b not found in checkpoint"):
|
errors.NotFoundError, "Key b not found in checkpoint"):
|
||||||
b_saver.restore(sess=sess, save_path=save_path)
|
b_saver.restore(sess=sess, save_path=save_path)
|
||||||
|
|
||||||
def testCheckpointNotFoundErrorRaised(self):
|
with self.assertRaises(errors.NotFoundError) as cs:
|
||||||
# Restore does some tricky exception handling to figure out if it should
|
b_saver.restore(sess=sess, save_path=save_path)
|
||||||
# load an object-based checkpoint. Tests that the exception handling isn't
|
|
||||||
# too broad.
|
# Make sure we don't have a confusing "During handling of the above
|
||||||
a = resource_variable_ops.ResourceVariable(1., name="a")
|
# exception" block in Python 3.
|
||||||
saver = saver_module.Saver([a])
|
self.assertNotIn("NewCheckpointReader", cs.exception.message)
|
||||||
with self.test_session() as sess:
|
|
||||||
with self.assertRaisesRegexp(
|
def testGraphChangedForRestoreErrorRaised(self):
|
||||||
errors.NotFoundError,
|
checkpoint_directory = self.get_temp_dir()
|
||||||
"Failed to find any matching files for path_which_does_not_exist"):
|
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||||
saver.restore(sess=sess, save_path="path_which_does_not_exist")
|
|
||||||
try:
|
with ops_lib.Graph().as_default() as g:
|
||||||
saver.restore(sess=sess, save_path="path_which_does_not_exist")
|
a = variables.Variable(1., name="a")
|
||||||
except errors.NotFoundError:
|
a_saver = saver_module.Saver([a])
|
||||||
# Make sure we don't have a confusing "During handling of the above
|
|
||||||
# exception" block in Python 3.
|
with self.test_session(graph=g) as sess:
|
||||||
# pylint: disable=no-value-for-parameter
|
sess.run(a.initializer)
|
||||||
exception_string = "\n".join(
|
save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
|
||||||
traceback.format_exception(*sys.exc_info()))
|
|
||||||
# pylint: enable=no-value-for-parameter
|
with ops_lib.Graph().as_default() as g:
|
||||||
self.assertNotIn("NewCheckpointReader", exception_string)
|
a = variables.Variable([1.], name="a")
|
||||||
|
a_saver = saver_module.Saver([a])
|
||||||
|
with self.test_session(graph=g) as sess:
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
errors.InvalidArgumentError,
|
||||||
|
"a mismatch between the current graph and the graph"):
|
||||||
|
a_saver.restore(sess=sess, save_path=save_path)
|
||||||
|
|
||||||
def testLoadFromObjectBasedGraph(self):
|
def testLoadFromObjectBasedGraph(self):
|
||||||
checkpoint_directory = self.get_temp_dir()
|
checkpoint_directory = self.get_temp_dir()
|
||||||
|
Loading…
Reference in New Issue
Block a user