From 25ea8bf4b58d8bdc6ca75afea8648aeef3cdb39f Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Tue, 19 Dec 2017 17:13:50 -0800 Subject: [PATCH] tfdbg: send graph traceback and source files from TensorBoard wrapper and hook To this end: * Refactor some common, shared constants and methods to a new file: common.py PiperOrigin-RevId: 179624165 --- tensorflow/python/debug/BUILD | 29 ++++++- tensorflow/python/debug/cli/cli_shared.py | 54 ++---------- tensorflow/python/debug/lib/common.py | 87 +++++++++++++++++++ tensorflow/python/debug/lib/common_test.py | 59 +++++++++++++ .../debug/lib/grpc_debug_test_server.py | 5 +- .../debug/lib/session_debug_grpc_test.py | 53 ++++++++++- tensorflow/python/debug/lib/source_remote.py | 6 ++ .../python/debug/wrappers/grpc_wrapper.py | 70 ++++++++++++++- tensorflow/python/debug/wrappers/hooks.py | 12 +++ .../debug/wrappers/local_cli_wrapper.py | 5 +- 10 files changed, 322 insertions(+), 58 deletions(-) create mode 100644 tensorflow/python/debug/lib/common.py create mode 100644 tensorflow/python/debug/lib/common_test.py diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 789771508e2..5d20701d8fd 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -52,6 +52,12 @@ py_library( ]), ) +py_library( + name = "common", + srcs = ["lib/common.py"], + srcs_version = "PY2AND3", +) + py_library( name = "debug_graphs", srcs = ["lib/debug_graphs.py"], @@ -117,6 +123,7 @@ py_library( srcs = ["lib/source_remote.py"], srcs_version = "PY2AND3", deps = [ + ":common", ":debug_service_pb2_grpc", "//tensorflow/core/debug:debug_service_proto_py", "//tensorflow/python/profiler:tfprof_logger", @@ -193,6 +200,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":command_parser", + ":common", ":debugger_cli_common", ":tensor_format", "//tensorflow/python:framework_for_generated_wrappers", @@ -334,7 +342,11 @@ py_library( name = "grpc_wrapper", srcs = ["wrappers/grpc_wrapper.py"], srcs_version = "PY2AND3", 
- deps = [":framework"], + deps = [ + ":common", + ":framework", + ":source_remote", + ], ) py_library( @@ -345,6 +357,7 @@ py_library( ":analyzer_cli", ":cli_shared", ":command_parser", + ":common", ":debug_data", ":debugger_cli_common", ":framework", @@ -439,6 +452,20 @@ py_binary( ], ) +py_test( + name = "common_test", + size = "small", + srcs = ["lib/common_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":common", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:platform_test", + ], +) + py_test( name = "debug_graphs_test", size = "small", diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py index df972eacf73..24431742eb0 100644 --- a/tensorflow/python/debug/cli/cli_shared.py +++ b/tensorflow/python/debug/cli/cli_shared.py @@ -25,6 +25,7 @@ import six from tensorflow.python.debug.cli import command_parser from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.cli import tensor_format +from tensorflow.python.debug.lib import common from tensorflow.python.framework import ops from tensorflow.python.ops import variables @@ -214,51 +215,6 @@ def error(msg): RL("ERROR: " + msg, COLOR_RED)]) -def get_graph_element_name(elem): - """Obtain the name or string representation of a graph element. - - If the graph element has the attribute "name", return name. Otherwise, return - a __str__ representation of the graph element. Certain graph elements, such as - `SparseTensor`s, do not have the attribute "name". - - Args: - elem: The graph element in question. - - Returns: - If the attribute 'name' is available, return the name. Otherwise, return - str(fetch). - """ - - return elem.name if hasattr(elem, "name") else str(elem) - - -def _get_fetch_names(fetches): - """Get a flattened list of the names in run() call fetches. - - Args: - fetches: Fetches of the `Session.run()` call. 
It maybe a Tensor, an - Operation or a Variable. It may also be nested lists, tuples or - dicts. See doc of `Session.run()` for more details. - - Returns: - (list of str) A flattened list of fetch names from `fetches`. - """ - - lines = [] - if isinstance(fetches, (list, tuple)): - for fetch in fetches: - lines.extend(_get_fetch_names(fetch)) - elif isinstance(fetches, dict): - for key in fetches: - lines.extend(_get_fetch_names(fetches[key])) - else: - # This ought to be a Tensor, an Operation or a Variable, for which the name - # attribute should be available. (Bottom-out condition of the recursion.) - lines.append(get_graph_element_name(fetches)) - - return lines - - def _recommend_command(command, description, indent=2, create_link=False): """Generate a RichTextLines object that describes a recommended command. @@ -327,14 +283,14 @@ def get_run_start_intro(run_call_count, (RichTextLines) Formatted intro message about the `Session.run()` call. """ - fetch_lines = _get_fetch_names(fetches) + fetch_lines = common.get_flattened_names(fetches) if not feed_dict: feed_dict_lines = [debugger_cli_common.RichLine(" (Empty)")] else: feed_dict_lines = [] for feed_key in feed_dict: - feed_key_name = get_graph_element_name(feed_key) + feed_key_name = common.get_graph_element_name(feed_key) feed_dict_line = debugger_cli_common.RichLine(" ") feed_dict_line += debugger_cli_common.RichLine( feed_key_name, @@ -446,10 +402,10 @@ def get_run_short_description(run_call_count, description = "run #%d: " % run_call_count if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)): - description += "1 fetch (%s); " % get_graph_element_name(fetches) + description += "1 fetch (%s); " % common.get_graph_element_name(fetches) else: # Could be (nested) list, tuple, dict or namedtuple. 
- num_fetches = len(_get_fetch_names(fetches)) + num_fetches = len(common.get_flattened_names(fetches)) if num_fetches > 1: description += "%d fetches; " % num_fetches else: diff --git a/tensorflow/python/debug/lib/common.py b/tensorflow/python/debug/lib/common.py new file mode 100644 index 00000000000..19a0d8c5010 --- /dev/null +++ b/tensorflow/python/debug/lib/common.py @@ -0,0 +1,87 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common values and methods for TensorFlow Debugger.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import json + +GRPC_URL_PREFIX = "grpc://" + +# A key for a Session.run() call. +RunKey = collections.namedtuple("RunKey", ["feed_names", "fetch_names"]) + + +def get_graph_element_name(elem): + """Obtain the name or string representation of a graph element. + + If the graph element has the attribute "name", return name. Otherwise, return + a __str__ representation of the graph element. Certain graph elements, such as + `SparseTensor`s, do not have the attribute "name". + + Args: + elem: The graph element in question. + + Returns: + If the attribute 'name' is available, return the name. Otherwise, return + str(fetch). 
+ """ + + return elem.name if hasattr(elem, "name") else str(elem) + + +def get_flattened_names(feeds_or_fetches): + """Get a flattened list of the names in run() call feeds or fetches. + + Args: + feeds_or_fetches: Feeds or fetches of the `Session.run()` call. It may be + a Tensor, an Operation or a Variable. It may also be nested lists, tuples + or dicts. See doc of `Session.run()` for more details. + + Returns: + (list of str) A flattened list of names from `feeds_or_fetches`. + """ + + lines = [] + if isinstance(feeds_or_fetches, (list, tuple)): + for item in feeds_or_fetches: + lines.extend(get_flattened_names(item)) + elif isinstance(feeds_or_fetches, dict): + for key in feeds_or_fetches: + lines.extend(get_flattened_names(feeds_or_fetches[key])) + else: + # This ought to be a Tensor, an Operation or a Variable, for which the name + # attribute should be available. (Bottom-out condition of the recursion.) + lines.append(get_graph_element_name(feeds_or_fetches)) + + return lines + + +def get_run_key(feed_dict, fetches): + """Summarize the names of feeds and fetches as a RunKey JSON string. + + Args: + feed_dict: The feed_dict given to the `Session.run()` call. + fetches: The fetches from the `Session.run()` call. + + Returns: + A JSON Array consisting of two items. The first item is a flattened + Array of the names of the feeds (the keys of `feed_dict`). The second item + is a flattened Array of the names of the fetches. + """ + return json.dumps(RunKey(get_flattened_names(list(feed_dict or [])), + get_flattened_names(fetches))) diff --git a/tensorflow/python/debug/lib/common_test.py b/tensorflow/python/debug/lib/common_test.py new file mode 100644 index 00000000000..5af0dafcf9f --- /dev/null +++ b/tensorflow/python/debug/lib/common_test.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for common values and methods of TensorFlow Debugger.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.python.debug.lib import common +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +class CommonTest(test_util.TensorFlowTestCase): + + def testOnFeedOneFetch(self): + a = constant_op.constant(10.0, name="a") + b = constant_op.constant(20.0, name="b") + run_key = common.get_run_key({a: 10.0}, [b]) + loaded = json.loads(run_key) + self.assertItemsEqual(["a:0"], loaded[0]) + self.assertItemsEqual(["b:0"], loaded[1]) + + def testGetRunKeyFlat(self): + a = constant_op.constant(10.0, name="a") + b = constant_op.constant(20.0, name="b") + run_key = common.get_run_key({a: 10.0}, [a, b]) + loaded = json.loads(run_key) + self.assertItemsEqual(["a:0"], loaded[0]) + self.assertItemsEqual(["a:0", "b:0"], loaded[1]) + + def testGetRunKeyNestedFetches(self): + a = constant_op.constant(10.0, name="a") + b = constant_op.constant(20.0, name="b") + c = constant_op.constant(30.0, name="c") + d = constant_op.constant(30.0, name="d") + run_key = common.get_run_key( + {}, {"set1": [a, b], "set2": {"c": c, "d": d}}) + loaded = json.loads(run_key) + self.assertItemsEqual([], loaded[0]) + self.assertItemsEqual(["a:0", "b:0", "c:0", "d:0"], loaded[1]) + + +if __name__ == "__main__": + 
googletest.main() diff --git a/tensorflow/python/debug/lib/grpc_debug_test_server.py b/tensorflow/python/debug/lib/grpc_debug_test_server.py index a637677d7d0..91700469484 100644 --- a/tensorflow/python/debug/lib/grpc_debug_test_server.py +++ b/tensorflow/python/debug/lib/grpc_debug_test_server.py @@ -310,7 +310,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): op_log_proto.id_to_string) raise ValueError( "Op '%s' does not exist in the tracebacks received by the debug " - "server.") + "server." % op_name) def query_origin_stack(self): """Query the stack of the origin of the execution call. @@ -348,6 +348,9 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): Raises: ValueError: If no source file is found at the given file_path. """ + if not self._source_files: + raise ValueError( + "This debug server has not received any source file contents yet.") for source_file_proto in self._source_files.source_files: if source_file_proto.file_path == file_path: return source_file_proto.lines[lineno - 1] diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py index 99781bd9d90..068e4f81c0d 100644 --- a/tensorflow/python/debug/lib/session_debug_grpc_test.py +++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py @@ -248,7 +248,7 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): self.assertEqual( 14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0])) - def testTensorBoardDebugHooWorks(self): + def testTensorBoardDebugHookWorks(self): u = variables.Variable(2.1, name="u") v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") @@ -261,8 +261,37 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): ["localhost:%d" % self._server_port]) sess = monitored_session._HookedSession(sess, [grpc_debug_hook]) + # Activate watch point on some a tensor before calling sess.run(). 
+ self._server.request_watch("u/read", 0, "DebugIdentity") self.assertAllClose(42.0, sess.run(w)) + # self.assertAllClose(42.0, sess.run(w)) + dump = debug_data.DebugDumpDir(self._dump_root) + self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity")) + + # Check that the server has received the stack trace. + self.assertTrue(self._server.query_op_traceback("u")) + self.assertTrue(self._server.query_op_traceback("u/read")) + self.assertTrue(self._server.query_op_traceback("v")) + self.assertTrue(self._server.query_op_traceback("v/read")) + self.assertTrue(self._server.query_op_traceback("w")) + + # Check that the server has received the python file content. + # Query an arbitrary line to make sure that is the case. + with open(__file__, "rt") as this_source_file: + first_line = this_source_file.readline().strip() + self.assertEqual( + first_line, self._server.query_source_file_line(__file__, 1)) + + self._server.clear_data() + # Call sess.run() again, and verify that this time the traceback and source + # code is not sent, because the graph version is not newer. + self.assertAllClose(42.0, sess.run(w)) + with self.assertRaises(ValueError): + self._server.query_op_traceback("delta_1") + with self.assertRaises(ValueError): + self._server.query_source_file_line(__file__, 1) + def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self): hooks.GrpcDebugHook(["grpc://foo:42424"]) hooks.GrpcDebugHook(["foo:42424"]) @@ -748,6 +777,28 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): # to disable the breakpoint at delta:0:DebugIdentity. self.assertSetEqual(set(), self._server_1.breakpoints) + if i == 0: + # Check that the server has received the stack trace. 
+ self.assertTrue(self._server_1.query_op_traceback("delta_1")) + self.assertTrue(self._server_1.query_op_traceback("delta_2")) + self.assertTrue(self._server_1.query_op_traceback("inc_v_1")) + self.assertTrue(self._server_1.query_op_traceback("inc_v_2")) + # Check that the server has received the python file content. + # Query an arbitrary line to make sure that is the case. + with open(__file__, "rt") as this_source_file: + first_line = this_source_file.readline().strip() + self.assertEqual( + first_line, self._server_1.query_source_file_line(__file__, 1)) + else: + # In later Session.run() calls, the traceback shouldn't have been sent + # because it is already sent in the 1st call. So calling + # query_op_traceback() should lead to an exception, because the test + # debug server clears the data at the beginning of every iteration. + with self.assertRaises(ValueError): + self._server_1.query_op_traceback("delta_1") + with self.assertRaises(ValueError): + self._server_1.query_source_file_line(__file__, 1) + def testGetGrpcDebugWatchesReturnsCorrectAnswer(self): with session.Session() as sess: v = variables.Variable(50.0, name="v") diff --git a/tensorflow/python/debug/lib/source_remote.py b/tensorflow/python/debug/lib/source_remote.py index 9d10d5a8d11..7fd8ceca1dd 100644 --- a/tensorflow/python/debug/lib/source_remote.py +++ b/tensorflow/python/debug/lib/source_remote.py @@ -24,6 +24,7 @@ import grpc from tensorflow.core.debug import debug_service_pb2 from tensorflow.core.protobuf import debug_pb2 +from tensorflow.python.debug.lib import common from tensorflow.python.debug.lib import debug_service_pb2_grpc from tensorflow.python.debug.lib import source_utils from tensorflow.python.platform import gfile @@ -130,6 +131,11 @@ def _send_call_tracebacks(destinations, """ if not isinstance(destinations, list): destinations = [destinations] + # Strip grpc:// prefix, if any is present. 
+ destinations = [ + dest[len(common.GRPC_URL_PREFIX):] + if dest.startswith(common.GRPC_URL_PREFIX) else dest + for dest in destinations] call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION if is_eager_execution diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py index 16b2018b413..cb9bf95782c 100644 --- a/tensorflow/python/debug/wrappers/grpc_wrapper.py +++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py @@ -17,15 +17,55 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys +import traceback + # Google-internal import(s). +from tensorflow.python.debug.lib import common from tensorflow.python.debug.wrappers import framework +def publish_traceback(debug_server_urls, + graph, + feed_dict, + fetches, + old_graph_version): + """Publish traceback and source code if graph version is new. + + `graph.version` is compared with `old_graph_version`. If the former is higher + (i.e., newer), the graph traceback and the associated source code is sent to + the debug server at the specified gRPC URLs. + + Args: + debug_server_urls: A single gRPC debug server URL as a `str` or a `list` of + debug server URLs. + graph: A Python `tf.Graph` object. + feed_dict: Feed dictionary given to the `Session.run()` call. + fetches: Fetches from the `Session.run()` call. + old_graph_version: Old graph version to compare to. + + Returns: + If `graph.version > old_graph_version`, the new graph version as an `int`. + Else, the `old_graph_version` is returned. + """ + # TODO(cais): Consider moving this back to the top, after grpc becomes a + # pip dependency of tensorflow or tf_debug. 
+ # pylint:disable=g-import-not-at-top + from tensorflow.python.debug.lib import source_remote + # pylint:enable=g-import-not-at-top + if graph.version > old_graph_version: + run_key = common.get_run_key(feed_dict, fetches) + source_remote.send_graph_tracebacks( + debug_server_urls, run_key, traceback.extract_stack(), graph, + send_source=True) + return graph.version + else: + return old_graph_version + + class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession): """Debug Session wrapper that send debug data to gRPC stream(s).""" - _GRPC_URL_PREFIX = "grpc://" - def __init__(self, sess, grpc_debug_server_addresses, @@ -94,8 +134,8 @@ class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession): return self._grpc_debug_server_urls def _normalize_grpc_url(self, address): - return (self._GRPC_URL_PREFIX + address - if not address.startswith(self._GRPC_URL_PREFIX) else address) + return (common.GRPC_URL_PREFIX + address + if not address.startswith(common.GRPC_URL_PREFIX) else address) class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): @@ -126,3 +166,25 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): watch_fn=_gated_grpc_watch_fn, thread_name_filter=thread_name_filter, log_usage=log_usage) + + # Keeps track of the latest version of Python graph object that has been + # sent to the debug servers. 
+ self._sent_graph_version = -sys.maxsize + + def run(self, + fetches, + feed_dict=None, + options=None, + run_metadata=None, + callable_runner=None, + callable_runner_args=None): + self._sent_graph_version = publish_traceback( + self._grpc_debug_server_urls, self.graph, feed_dict, fetches, + self._sent_graph_version) + return super(TensorBoardDebugWrapperSession, self).run( + fetches, + feed_dict=feed_dict, + options=options, + run_metadata=run_metadata, + callable_runner=callable_runner, + callable_runner_args=callable_runner_args) diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py index 43066996248..aa9f0650406 100644 --- a/tensorflow/python/debug/wrappers/hooks.py +++ b/tensorflow/python/debug/wrappers/hooks.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + from tensorflow.core.protobuf import config_pb2 from tensorflow.python.debug.lib import debug_utils from tensorflow.python.debug.lib import stepper @@ -331,3 +333,13 @@ class TensorBoardDebugHook(GrpcDebugHook): watch_fn=_gated_grpc_watch_fn, thread_name_filter=thread_name_filter, log_usage=log_usage) + + self._grpc_debug_server_addresses = grpc_debug_server_addresses + self._sent_graph_version = -sys.maxsize + + def before_run(self, run_context): + self._sent_graph_version = grpc_wrapper.publish_traceback( + self._grpc_debug_server_addresses, run_context.session.graph, + run_context.original_args.feed_dict, run_context.original_args.fetches, + self._sent_graph_version) + return super(TensorBoardDebugHook, self).before_run(run_context) diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index 5bf6d9d1f4a..c46a4e7d1aa 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -31,6 +31,7 @@ from tensorflow.python.debug.cli import 
debugger_cli_common from tensorflow.python.debug.cli import profile_analyzer_cli from tensorflow.python.debug.cli import stepper_cli from tensorflow.python.debug.cli import ui_factory +from tensorflow.python.debug.lib import common from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.wrappers import framework @@ -464,7 +465,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): feed_key = None feed_value = None for key in self._feed_dict: - key_name = cli_shared.get_graph_element_name(key) + key_name = common.get_graph_element_name(key) if key_name == tensor_name: feed_key = key_name feed_value = self._feed_dict[key] @@ -561,7 +562,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): list(self._tensor_filters.keys())) if self._feed_dict: # Register tab completion for feed_dict keys. - feed_keys = [cli_shared.get_graph_element_name(key) + feed_keys = [common.get_graph_element_name(key) for key in self._feed_dict.keys()] curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)