Improve tfdbg's handling of runtime errors

* In some cases the RuntimeError object (tf_error in cli_shared.py) doesn't have
  the op or its name available. Handle that situation properly.
* Previously, we used client graphs in the debugger CLI whenever they were
  available. This has caused issues in which the device names mismatch
  (e.g., "/device:GPU:0" vs "/job:localhost/replica:0/task:0/device:CPU:0").
  This CL fixes that by favoring the runtime graph on the disk over the client graph.
  The former has the actual device names.
  Use the latter only if the former isn't available for some reason (e.g.,
  writing graph to the disk failed.)

PiperOrigin-RevId: 200128582
This commit is contained in:
Shanqing Cai 2018-06-11 15:57:39 -07:00 committed by TensorFlower Gardener
parent b12f58cfcf
commit 49ed096fb3
4 changed files with 57 additions and 39 deletions

View File

@ -451,42 +451,48 @@ def get_error_intro(tf_error):
sample commands for debugging.
"""
op_name = tf_error.op.name
if hasattr(tf_error, "op") and hasattr(tf_error.op, "name"):
op_name = tf_error.op.name
else:
op_name = None
intro_lines = [
"--------------------------------------",
RL("!!! An error occurred during the run !!!", "blink"),
"",
"You may use the following commands to debug:",
]
out = debugger_cli_common.rich_text_lines_from_rich_line_list(intro_lines)
out.extend(
_recommend_command("ni -a -d -t %s" % op_name,
"Inspect information about the failing op.",
create_link=True))
out.extend(
_recommend_command("li -r %s" % op_name,
"List inputs to the failing op, recursively.",
create_link=True))
if op_name is not None:
out.extend(debugger_cli_common.RichTextLines(
["You may use the following commands to debug:"]))
out.extend(
_recommend_command("ni -a -d -t %s" % op_name,
"Inspect information about the failing op.",
create_link=True))
out.extend(
_recommend_command("li -r %s" % op_name,
"List inputs to the failing op, recursively.",
create_link=True))
out.extend(
_recommend_command(
"lt",
"List all tensors dumped during the failing run() call.",
create_link=True))
out.extend(
_recommend_command(
"lt",
"List all tensors dumped during the failing run() call.",
create_link=True))
else:
out.extend(debugger_cli_common.RichTextLines([
"WARNING: Cannot determine the name of the op that caused the error."]))
more_lines = [
"",
"Op name: " + op_name,
"Op name: %s" % op_name,
"Error type: " + str(type(tf_error)),
"",
"Details:",
str(tf_error),
"",
"WARNING: Using client GraphDef due to the error, instead of "
"executor GraphDefs.",
"--------------------------------------",
"",
]

View File

@ -372,6 +372,11 @@ class GetErrorIntroTest(test_util.TensorFlowTestCase):
self.assertEqual("Details:", error_intro.lines[14])
self.assertStartsWith(error_intro.lines[15], "foo description")
def testGetErrorIntroForNoOpName(self):
tf_error = errors.OpError(None, None, "Fake OpError", -1)
error_intro = cli_shared.get_error_intro(tf_error)
self.assertIn("Cannot determine the name of the op", error_intro.lines[3])
if __name__ == "__main__":
googletest.main()

View File

@ -69,6 +69,12 @@ run
exit
EOF
cat << EOF | ${DEBUG_ERRORS_BIN} --error=uninitialized_variable --debug --ui_type=readline
run
ni -a -d -t v/read
exit
EOF
cat << EOF | ${DEBUG_MNIST_BIN} --debug --max_steps=1 --fake_data --ui_type=readline
run -t 1
run --node_name_filter hidden --op_type_filter MatMul

View File

@ -748,7 +748,7 @@ class DebugDumpDir(object):
return sum(len(self._dump_tensor_data[device_name])
for device_name in self._dump_tensor_data)
def _load_partition_graphs(self, partition_graphs, validate):
def _load_partition_graphs(self, client_partition_graphs, validate):
"""Load and process partition graphs.
Load the graphs; parse the input and control input structure; obtain the
@ -757,8 +757,10 @@ class DebugDumpDir(object):
tensor dumps.
Args:
partition_graphs: A repeated field of GraphDefs representing the
partition graphs executed by the TensorFlow runtime.
client_partition_graphs: A repeated field of GraphDefs representing the
partition graphs executed by the TensorFlow runtime, from the Python
client. These partition graphs are used only if partition graphs
cannot be loaded from the dump directory on the file system.
validate: (`bool`) Whether the dump files are to be validated against the
partition graphs.
@ -769,24 +771,23 @@ class DebugDumpDir(object):
self._debug_graphs = {}
self._node_devices = {}
if partition_graphs:
partition_graphs_and_device_names = [
(partition_graph, None) for partition_graph in partition_graphs]
else:
partition_graphs_and_device_names = []
for device_name in self._device_names:
partition_graph = None
if device_name in self._dump_graph_file_paths:
partition_graph = _load_graph_def_from_event_file(
self._dump_graph_file_paths[device_name])
else:
partition_graph = self._find_partition_graph(partition_graphs,
device_name)
if partition_graph:
partition_graphs_and_device_names.append((partition_graph,
device_name))
else:
logging.warn("Failed to load partition graphs from disk.")
partition_graphs_and_device_names = []
for device_name in self._device_names:
partition_graph = None
if device_name in self._dump_graph_file_paths:
partition_graph = _load_graph_def_from_event_file(
self._dump_graph_file_paths[device_name])
else:
logging.warn(
"Failed to load partition graphs for device %s from disk. "
"As a fallback, the client graphs will be used. This "
"may cause mismatches in device names." % device_name)
partition_graph = self._find_partition_graph(client_partition_graphs,
device_name)
if partition_graph:
partition_graphs_and_device_names.append((partition_graph,
device_name))
for partition_graph, maybe_device_name in partition_graphs_and_device_names:
debug_graph = debug_graphs.DebugGraph(partition_graph,