tfdbg: send graph traceback and source files from TensorBoard wrapper and hook
To this end:
* Refactor some common, shared constants and methods to a new file: common.py

PiperOrigin-RevId: 179624165

commit 25ea8bf4b5 (parent ef523e77ca)
@@ -52,6 +52,12 @@ py_library(
    ]),
)

py_library(
    name = "common",
    srcs = ["lib/common.py"],
    srcs_version = "PY2AND3",
)

py_library(
    name = "debug_graphs",
    srcs = ["lib/debug_graphs.py"],
@@ -117,6 +123,7 @@ py_library(
    srcs = ["lib/source_remote.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":common",
        ":debug_service_pb2_grpc",
        "//tensorflow/core/debug:debug_service_proto_py",
        "//tensorflow/python/profiler:tfprof_logger",
@@ -193,6 +200,7 @@ py_library(
    srcs_version = "PY2AND3",
    deps = [
        ":command_parser",
        ":common",
        ":debugger_cli_common",
        ":tensor_format",
        "//tensorflow/python:framework_for_generated_wrappers",
@@ -334,7 +342,11 @@ py_library(
    name = "grpc_wrapper",
    srcs = ["wrappers/grpc_wrapper.py"],
    srcs_version = "PY2AND3",
-   deps = [":framework"],
+   deps = [
+       ":common",
+       ":framework",
+       ":source_remote",
+   ],
)

py_library(
@@ -345,6 +357,7 @@ py_library(
        ":analyzer_cli",
        ":cli_shared",
        ":command_parser",
        ":common",
        ":debug_data",
        ":debugger_cli_common",
        ":framework",
@@ -439,6 +452,20 @@ py_binary(
    ],
)

py_test(
    name = "common_test",
    size = "small",
    srcs = ["lib/common_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":common",
        "//tensorflow/python:client",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:platform_test",
    ],
)

py_test(
    name = "debug_graphs_test",
    size = "small",
@@ -25,6 +25,7 @@ import six
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import tensor_format
from tensorflow.python.debug.lib import common
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables

@@ -214,51 +215,6 @@ def error(msg):
      RL("ERROR: " + msg, COLOR_RED)])


-def get_graph_element_name(elem):
-  """Obtain the name or string representation of a graph element.
-
-  If the graph element has the attribute "name", return name. Otherwise, return
-  a __str__ representation of the graph element. Certain graph elements, such as
-  `SparseTensor`s, do not have the attribute "name".
-
-  Args:
-    elem: The graph element in question.
-
-  Returns:
-    If the attribute 'name' is available, return the name. Otherwise, return
-    str(fetch).
-  """
-
-  return elem.name if hasattr(elem, "name") else str(elem)


-def _get_fetch_names(fetches):
-  """Get a flattened list of the names in run() call fetches.
-
-  Args:
-    fetches: Fetches of the `Session.run()` call. It maybe a Tensor, an
-      Operation or a Variable. It may also be nested lists, tuples or
-      dicts. See doc of `Session.run()` for more details.
-
-  Returns:
-    (list of str) A flattened list of fetch names from `fetches`.
-  """
-
-  lines = []
-  if isinstance(fetches, (list, tuple)):
-    for fetch in fetches:
-      lines.extend(_get_fetch_names(fetch))
-  elif isinstance(fetches, dict):
-    for key in fetches:
-      lines.extend(_get_fetch_names(fetches[key]))
-  else:
-    # This ought to be a Tensor, an Operation or a Variable, for which the name
-    # attribute should be available. (Bottom-out condition of the recursion.)
-    lines.append(get_graph_element_name(fetches))
-
-  return lines


def _recommend_command(command, description, indent=2, create_link=False):
  """Generate a RichTextLines object that describes a recommended command.

@@ -327,14 +283,14 @@ def get_run_start_intro(run_call_count,
    (RichTextLines) Formatted intro message about the `Session.run()` call.
  """

-  fetch_lines = _get_fetch_names(fetches)
+  fetch_lines = common.get_flattened_names(fetches)

  if not feed_dict:
    feed_dict_lines = [debugger_cli_common.RichLine(" (Empty)")]
  else:
    feed_dict_lines = []
    for feed_key in feed_dict:
-      feed_key_name = get_graph_element_name(feed_key)
+      feed_key_name = common.get_graph_element_name(feed_key)
      feed_dict_line = debugger_cli_common.RichLine(" ")
      feed_dict_line += debugger_cli_common.RichLine(
          feed_key_name,
@@ -446,10 +402,10 @@ def get_run_short_description(run_call_count,
  description = "run #%d: " % run_call_count

  if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
-    description += "1 fetch (%s); " % get_graph_element_name(fetches)
+    description += "1 fetch (%s); " % common.get_graph_element_name(fetches)
  else:
    # Could be (nested) list, tuple, dict or namedtuple.
-    num_fetches = len(_get_fetch_names(fetches))
+    num_fetches = len(common.get_flattened_names(fetches))
    if num_fetches > 1:
      description += "%d fetches; " % num_fetches
    else:
tensorflow/python/debug/lib/common.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common values and methods for TensorFlow Debugger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json

GRPC_URL_PREFIX = "grpc://"

# A key for a Session.run() call.
RunKey = collections.namedtuple("RunKey", ["feed_names", "fetch_names"])


def get_graph_element_name(elem):
  """Obtain the name or string representation of a graph element.

  If the graph element has the attribute "name", return the name. Otherwise,
  return a __str__ representation of the graph element. Certain graph elements,
  such as `SparseTensor`s, do not have the attribute "name".

  Args:
    elem: The graph element in question.

  Returns:
    If the attribute 'name' is available, return the name. Otherwise, return
    str(elem).
  """

  return elem.name if hasattr(elem, "name") else str(elem)


def get_flattened_names(feeds_or_fetches):
  """Get a flattened list of the names in run() call feeds or fetches.

  Args:
    feeds_or_fetches: Feeds or fetches of the `Session.run()` call. It may be
      a Tensor, an Operation or a Variable. It may also be nested lists, tuples
      or dicts. See doc of `Session.run()` for more details.

  Returns:
    (list of str) A flattened list of fetch names from `feeds_or_fetches`.
  """

  lines = []
  if isinstance(feeds_or_fetches, (list, tuple)):
    for item in feeds_or_fetches:
      lines.extend(get_flattened_names(item))
  elif isinstance(feeds_or_fetches, dict):
    for key in feeds_or_fetches:
      lines.extend(get_flattened_names(feeds_or_fetches[key]))
  else:
    # This ought to be a Tensor, an Operation or a Variable, for which the name
    # attribute should be available. (Bottom-out condition of the recursion.)
    lines.append(get_graph_element_name(feeds_or_fetches))

  return lines


def get_run_key(feed_dict, fetches):
  """Summarize the names of feeds and fetches as a RunKey JSON string.

  Args:
    feed_dict: The feed_dict given to the `Session.run()` call.
    fetches: The fetches from the `Session.run()` call.

  Returns:
    A JSON Array consisting of two items. The first item is a flattened
    Array of the names of the feeds. The second item is a flattened Array of
    the names of the fetches.
  """
  return json.dumps(RunKey(get_flattened_names(feed_dict),
                           get_flattened_names(fetches)))
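For orientation, a minimal usage sketch of the new helpers (not part of the diff). It assumes ordinary TF 1.x graph mode, where a constant named "a" yields the tensor name "a:0"; the SparseTensor is included only because the docstring above calls out that it has no "name" attribute.

import tensorflow as tf

from tensorflow.python.debug.lib import common

a = tf.constant(1.0, name="a")
b = tf.constant(2.0, name="b")
sp = tf.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 2])

# Tensors expose .name; SparseTensor does not, so it falls back to str().
print(common.get_graph_element_name(a))   # "a:0"
print(common.get_graph_element_name(sp))  # e.g. "SparseTensor(indices=..., ...)"

# get_run_key() flattens feeds and fetches into a JSON array of two name lists.
run_key = common.get_run_key({"a": a}, {"pair": [a, b]})
print(run_key)  # [["a:0"], ["a:0", "b:0"]]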
tensorflow/python/debug/lib/common_test.py (new file, 59 lines)
@@ -0,0 +1,59 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for common values and methods of TensorFlow Debugger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

from tensorflow.python.debug.lib import common
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


class CommonTest(test_util.TensorFlowTestCase):

  def testOnFeedOneFetch(self):
    a = constant_op.constant(10.0, name="a")
    b = constant_op.constant(20.0, name="b")
    run_key = common.get_run_key({"a": a}, [b])
    loaded = json.loads(run_key)
    self.assertItemsEqual(["a:0"], loaded[0])
    self.assertItemsEqual(["b:0"], loaded[1])

  def testGetRunKeyFlat(self):
    a = constant_op.constant(10.0, name="a")
    b = constant_op.constant(20.0, name="b")
    run_key = common.get_run_key({"a": a}, [a, b])
    loaded = json.loads(run_key)
    self.assertItemsEqual(["a:0"], loaded[0])
    self.assertItemsEqual(["a:0", "b:0"], loaded[1])

  def testGetRunKeyNestedFetches(self):
    a = constant_op.constant(10.0, name="a")
    b = constant_op.constant(20.0, name="b")
    c = constant_op.constant(30.0, name="c")
    d = constant_op.constant(30.0, name="d")
    run_key = common.get_run_key(
        {}, {"set1": [a, b], "set2": {"c": c, "d": d}})
    loaded = json.loads(run_key)
    self.assertItemsEqual([], loaded[0])
    self.assertItemsEqual(["a:0", "b:0", "c:0", "d:0"], loaded[1])


if __name__ == "__main__":
  googletest.main()
@@ -310,7 +310,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
        op_log_proto.id_to_string)
    raise ValueError(
        "Op '%s' does not exist in the tracebacks received by the debug "
-       "server.")
+       "server." % op_name)

  def query_origin_stack(self):
    """Query the stack of the origin of the execution call.
@@ -348,6 +348,9 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
    Raises:
      ValueError: If no source file is found at the given file_path.
    """
    if not self._source_files:
      raise ValueError(
          "This debug server has not received any source file contents yet.")
    for source_file_proto in self._source_files.source_files:
      if source_file_proto.file_path == file_path:
        return source_file_proto.lines[lineno - 1]
@@ -248,7 +248,7 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
    self.assertEqual(
        14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))

-  def testTensorBoardDebugHooWorks(self):
+  def testTensorBoardDebugHookWorks(self):
    u = variables.Variable(2.1, name="u")
    v = variables.Variable(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")
@@ -261,8 +261,37 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
        ["localhost:%d" % self._server_port])
    sess = monitored_session._HookedSession(sess, [grpc_debug_hook])

    # Activate watch point on a tensor before calling sess.run().
    self._server.request_watch("u/read", 0, "DebugIdentity")
    self.assertAllClose(42.0, sess.run(w))

    # self.assertAllClose(42.0, sess.run(w))
    dump = debug_data.DebugDumpDir(self._dump_root)
    self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))

    # Check that the server has received the stack trace.
    self.assertTrue(self._server.query_op_traceback("u"))
    self.assertTrue(self._server.query_op_traceback("u/read"))
    self.assertTrue(self._server.query_op_traceback("v"))
    self.assertTrue(self._server.query_op_traceback("v/read"))
    self.assertTrue(self._server.query_op_traceback("w"))

    # Check that the server has received the python file content.
    # Query an arbitrary line to make sure that is the case.
    with open(__file__, "rt") as this_source_file:
      first_line = this_source_file.readline().strip()
    self.assertEqual(
        first_line, self._server.query_source_file_line(__file__, 1))

    self._server.clear_data()
    # Call sess.run() again, and verify that this time the traceback and source
    # code is not sent, because the graph version is not newer.
    self.assertAllClose(42.0, sess.run(w))
    with self.assertRaises(ValueError):
      self._server.query_op_traceback("delta_1")
    with self.assertRaises(ValueError):
      self._server.query_source_file_line(__file__, 1)

  def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self):
    hooks.GrpcDebugHook(["grpc://foo:42424"])
    hooks.GrpcDebugHook(["foo:42424"])
@@ -748,6 +777,28 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
        # to disable the breakpoint at delta:0:DebugIdentity.
        self.assertSetEqual(set(), self._server_1.breakpoints)

        if i == 0:
          # Check that the server has received the stack trace.
          self.assertTrue(self._server_1.query_op_traceback("delta_1"))
          self.assertTrue(self._server_1.query_op_traceback("delta_2"))
          self.assertTrue(self._server_1.query_op_traceback("inc_v_1"))
          self.assertTrue(self._server_1.query_op_traceback("inc_v_2"))
          # Check that the server has received the python file content.
          # Query an arbitrary line to make sure that is the case.
          with open(__file__, "rt") as this_source_file:
            first_line = this_source_file.readline().strip()
          self.assertEqual(
              first_line, self._server_1.query_source_file_line(__file__, 1))
        else:
          # In later Session.run() calls, the traceback shouldn't have been sent
          # because it is already sent in the 1st call. So calling
          # query_op_traceback() should lead to an exception, because the test
          # debug server clears the data at the beginning of every iteration.
          with self.assertRaises(ValueError):
            self._server_1.query_op_traceback("delta_1")
          with self.assertRaises(ValueError):
            self._server_1.query_source_file_line(__file__, 1)

  def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
    with session.Session() as sess:
      v = variables.Variable(50.0, name="v")
@@ -24,6 +24,7 @@ import grpc

from tensorflow.core.debug import debug_service_pb2
from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_service_pb2_grpc
from tensorflow.python.debug.lib import source_utils
from tensorflow.python.platform import gfile
@@ -130,6 +131,11 @@ def _send_call_tracebacks(destinations,
  """
  if not isinstance(destinations, list):
    destinations = [destinations]
  # Strip grpc:// prefix, if any is present.
  destinations = [
      dest[len(common.GRPC_URL_PREFIX):]
      if dest.startswith(common.GRPC_URL_PREFIX) else dest
      for dest in destinations]

  call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION
               if is_eager_execution
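Aside (not part of the commit): the stripping above is the mirror image of `_normalize_grpc_url` in grpc_wrapper.py further down; both now key off `common.GRPC_URL_PREFIX` instead of ad-hoc string literals. A tiny sketch of the two directions, with hypothetical helper names and an example port:

from tensorflow.python.debug.lib import common

def strip_grpc_prefix(dest):
  # Hypothetical helper mirroring source_remote: "grpc://host:port" -> "host:port".
  if dest.startswith(common.GRPC_URL_PREFIX):
    return dest[len(common.GRPC_URL_PREFIX):]
  return dest

def add_grpc_prefix(address):
  # Hypothetical helper mirroring GrpcDebugWrapperSession._normalize_grpc_url.
  if address.startswith(common.GRPC_URL_PREFIX):
    return address
  return common.GRPC_URL_PREFIX + address

assert strip_grpc_prefix("grpc://localhost:6064") == "localhost:6064"
assert add_grpc_prefix("localhost:6064") == "grpc://localhost:6064"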
@@ -17,15 +17,55 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import traceback

# Google-internal import(s).
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.wrappers import framework


def publish_traceback(debug_server_urls,
                      graph,
                      feed_dict,
                      fetches,
                      old_graph_version):
  """Publish traceback and source code if graph version is new.

  `graph.version` is compared with `old_graph_version`. If the former is higher
  (i.e., newer), the graph traceback and the associated source code is sent to
  the debug server at the specified gRPC URLs.

  Args:
    debug_server_urls: A single gRPC debug server URL as a `str` or a `list` of
      debug server URLs.
    graph: A Python `tf.Graph` object.
    feed_dict: Feed dictionary given to the `Session.run()` call.
    fetches: Fetches from the `Session.run()` call.
    old_graph_version: Old graph version to compare to.

  Returns:
    If `graph.version > old_graph_version`, the new graph version as an `int`.
    Else, the `old_graph_version` is returned.
  """
  # TODO(cais): Consider moving this back to the top, after grpc becomes a
  # pip dependency of tensorflow or tf_debug.
  # pylint:disable=g-import-not-at-top
  from tensorflow.python.debug.lib import source_remote
  # pylint:enable=g-import-not-at-top
  if graph.version > old_graph_version:
    run_key = common.get_run_key(feed_dict, fetches)
    source_remote.send_graph_tracebacks(
        debug_server_urls, run_key, traceback.extract_stack(), graph,
        send_source=True)
    return graph.version
  else:
    return old_graph_version


class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
  """Debug Session wrapper that send debug data to gRPC stream(s)."""

  _GRPC_URL_PREFIX = "grpc://"

  def __init__(self,
               sess,
               grpc_debug_server_addresses,
@@ -94,8 +134,8 @@ class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
    return self._grpc_debug_server_urls

  def _normalize_grpc_url(self, address):
-    return (self._GRPC_URL_PREFIX + address
-            if not address.startswith(self._GRPC_URL_PREFIX) else address)
+    return (common.GRPC_URL_PREFIX + address
+            if not address.startswith(common.GRPC_URL_PREFIX) else address)


class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
@@ -126,3 +166,25 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
        watch_fn=_gated_grpc_watch_fn,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    # Keeps track of the latest version of Python graph object that has been
    # sent to the debug servers.
    self._sent_graph_version = -sys.maxint

  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None):
    self._sent_graph_version = publish_traceback(
        self._grpc_debug_server_urls, self.graph, feed_dict, fetches,
        self._sent_graph_version)
    return super(TensorBoardDebugWrapperSession, self).run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata,
        callable_runner=callable_runner,
        callable_runner_args=callable_runner_args)
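For context, a minimal, hypothetical usage sketch of the wrapper above (not part of the commit). It assumes TF 1.x graph mode and a TensorBoard debugger plugin, or the gRPC debug test server shown earlier, listening at localhost:6064 (placeholder address):

import tensorflow as tf

from tensorflow.python.debug.wrappers import grpc_wrapper

u = tf.Variable(2.1, name="u")
v = tf.Variable(20.0, name="v")
w = tf.multiply(u, v, name="w")

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# On the first run(), graph.version exceeds the initial -sys.maxint, so the
# graph traceback and source files are published to the debug server; later
# run() calls skip publishing unless the graph has grown since.
sess = grpc_wrapper.TensorBoardDebugWrapperSession(sess, "localhost:6064")
print(sess.run(w))  # 42.0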
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import stepper
@@ -331,3 +333,13 @@ class TensorBoardDebugHook(GrpcDebugHook):
        watch_fn=_gated_grpc_watch_fn,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    self._grpc_debug_server_addresses = grpc_debug_server_addresses
    self._sent_graph_version = -sys.maxint

  def before_run(self, run_context):
    self._sent_graph_version = grpc_wrapper.publish_traceback(
        self._grpc_debug_server_addresses, run_context.session.graph,
        run_context.original_args.feed_dict, run_context.original_args.fetches,
        self._sent_graph_version)
    return super(TensorBoardDebugHook, self).before_run(run_context)
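And a corresponding sketch for the hook path (not part of the commit), using the same placeholder address. As the diff above shows, before_run() publishes the graph traceback and source files once per new graph version, then defers to GrpcDebugHook for the debug-op watches:

import tensorflow as tf

from tensorflow.python.debug.wrappers import hooks

v = tf.Variable(10.0, name="v")
inc_v = tf.assign_add(v, 1.0, name="inc_v")

hook = hooks.TensorBoardDebugHook(["localhost:6064"])
with tf.train.MonitoredSession(hooks=[hook]) as sess:
  # MonitoredSession initializes variables itself; each run() below passes
  # through TensorBoardDebugHook.before_run() first.
  sess.run(inc_v)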
@@ -31,6 +31,7 @@ from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import profile_analyzer_cli
from tensorflow.python.debug.cli import stepper_cli
from tensorflow.python.debug.cli import ui_factory
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import framework

@@ -464,7 +465,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
    feed_key = None
    feed_value = None
    for key in self._feed_dict:
-      key_name = cli_shared.get_graph_element_name(key)
+      key_name = common.get_graph_element_name(key)
      if key_name == tensor_name:
        feed_key = key_name
        feed_value = self._feed_dict[key]
@@ -561,7 +562,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
        list(self._tensor_filters.keys()))
    if self._feed_dict:
      # Register tab completion for feed_dict keys.
-      feed_keys = [cli_shared.get_graph_element_name(key)
+      feed_keys = [common.get_graph_element_name(key)
                   for key in self._feed_dict.keys()]
      curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)