[tfdbg] Fix gRPC message length limit issue in source remote
Fixes https://github.com/tensorflow/tensorboard/issues/1103 PiperOrigin-RevId: 257419107
This commit is contained in:
parent
d7fbbc0023
commit
0f5d0cfc2a
@ -346,7 +346,10 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
|||||||
if self._server_started:
|
if self._server_started:
|
||||||
raise ValueError("Server has already started running")
|
raise ValueError("Server has already started running")
|
||||||
|
|
||||||
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
no_max_message_sizes = [("grpc.max_receive_message_length", -1),
|
||||||
|
("grpc.max_send_message_length", -1)]
|
||||||
|
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10),
|
||||||
|
options=no_max_message_sizes)
|
||||||
debug_service_pb2_grpc.add_EventListenerServicer_to_server(self,
|
debug_service_pb2_grpc.add_EventListenerServicer_to_server(self,
|
||||||
self.server)
|
self.server)
|
||||||
self.server.add_insecure_port("[::]:%d" % self._server_port)
|
self.server.add_insecure_port("[::]:%d" % self._server_port)
|
||||||
|
@ -28,7 +28,6 @@ from tensorflow.python.debug.lib import common
|
|||||||
from tensorflow.python.debug.lib import debug_service_pb2_grpc
|
from tensorflow.python.debug.lib import debug_service_pb2_grpc
|
||||||
from tensorflow.python.debug.lib import source_utils
|
from tensorflow.python.debug.lib import source_utils
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.platform import tf_logging
|
|
||||||
from tensorflow.python.profiler import tfprof_logger
|
from tensorflow.python.profiler import tfprof_logger
|
||||||
|
|
||||||
|
|
||||||
@ -96,11 +95,6 @@ def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string):
|
|||||||
return non_tf_files
|
return non_tf_files
|
||||||
|
|
||||||
|
|
||||||
def grpc_message_length_bytes():
|
|
||||||
"""Maximum gRPC message length in bytes."""
|
|
||||||
return 4 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
def _send_call_tracebacks(destinations,
|
def _send_call_tracebacks(destinations,
|
||||||
origin_stack,
|
origin_stack,
|
||||||
is_eager_execution=False,
|
is_eager_execution=False,
|
||||||
@ -169,20 +163,14 @@ def _send_call_tracebacks(destinations,
|
|||||||
debugged_source_files.append(source_files)
|
debugged_source_files.append(source_files)
|
||||||
|
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
channel = grpc.insecure_channel(destination)
|
no_max_message_sizes = [("grpc.max_receive_message_length", -1),
|
||||||
|
("grpc.max_send_message_length", -1)]
|
||||||
|
channel = grpc.insecure_channel(destination, options=no_max_message_sizes)
|
||||||
stub = debug_service_pb2_grpc.EventListenerStub(channel)
|
stub = debug_service_pb2_grpc.EventListenerStub(channel)
|
||||||
stub.SendTracebacks(call_traceback)
|
stub.SendTracebacks(call_traceback)
|
||||||
if send_source:
|
if send_source:
|
||||||
for path, source_files in zip(
|
for source_files in debugged_source_files:
|
||||||
source_file_paths, debugged_source_files):
|
stub.SendSourceFiles(source_files)
|
||||||
if source_files.ByteSize() < grpc_message_length_bytes():
|
|
||||||
stub.SendSourceFiles(source_files)
|
|
||||||
else:
|
|
||||||
tf_logging.warn(
|
|
||||||
"The content of the source file at %s is not sent to "
|
|
||||||
"gRPC debug server %s, because the message size exceeds "
|
|
||||||
"gRPC message length limit (%d bytes)." % (
|
|
||||||
path, destination, grpc_message_length_bytes()))
|
|
||||||
|
|
||||||
|
|
||||||
def send_graph_tracebacks(destinations,
|
def send_graph_tracebacks(destinations,
|
||||||
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
|||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
from tensorflow.core.debug import debug_service_pb2
|
from tensorflow.core.debug import debug_service_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.debug.lib import grpc_debug_test_server
|
from tensorflow.python.debug.lib import grpc_debug_test_server
|
||||||
@ -129,9 +131,17 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
send_traceback = traceback.extract_stack()
|
send_traceback = traceback.extract_stack()
|
||||||
send_lineno = line_number_above()
|
send_lineno = line_number_above()
|
||||||
source_remote.send_graph_tracebacks(
|
|
||||||
[self._server_address, self._server_address_2],
|
with test.mock.patch.object(
|
||||||
"dummy_run_key", send_traceback, sess.graph)
|
grpc, "insecure_channel",
|
||||||
|
wraps=grpc.insecure_channel) as mock_grpc_channel:
|
||||||
|
source_remote.send_graph_tracebacks(
|
||||||
|
[self._server_address, self._server_address_2],
|
||||||
|
"dummy_run_key", send_traceback, sess.graph)
|
||||||
|
mock_grpc_channel.assert_called_with(
|
||||||
|
test.mock.ANY,
|
||||||
|
options=[("grpc.max_receive_message_length", -1),
|
||||||
|
("grpc.max_send_message_length", -1)])
|
||||||
|
|
||||||
servers = [self._server, self._server_2]
|
servers = [self._server, self._server_2]
|
||||||
for server in servers:
|
for server in servers:
|
||||||
@ -157,51 +167,6 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(["dummy_run_key"], server.query_call_keys())
|
self.assertEqual(["dummy_run_key"], server.query_call_keys())
|
||||||
self.assertEqual([sess.graph.version], server.query_graph_versions())
|
self.assertEqual([sess.graph.version], server.query_graph_versions())
|
||||||
|
|
||||||
def testSourceFileSizeExceedsGrpcMessageLengthLimit(self):
|
|
||||||
"""In case source file size exceeds the grpc message length limit.
|
|
||||||
|
|
||||||
it ought not to have been sent to the server.
|
|
||||||
"""
|
|
||||||
this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit"
|
|
||||||
|
|
||||||
# Patch the method to simulate a very small message length limit.
|
|
||||||
with test.mock.patch.object(
|
|
||||||
source_remote, "grpc_message_length_bytes", return_value=2):
|
|
||||||
with session.Session() as sess:
|
|
||||||
a = variables.Variable(21.0, name="two/a")
|
|
||||||
a_lineno = line_number_above()
|
|
||||||
b = variables.Variable(2.0, name="two/b")
|
|
||||||
b_lineno = line_number_above()
|
|
||||||
x = math_ops.add(a, b, name="two/x")
|
|
||||||
x_lineno = line_number_above()
|
|
||||||
|
|
||||||
send_traceback = traceback.extract_stack()
|
|
||||||
send_lineno = line_number_above()
|
|
||||||
source_remote.send_graph_tracebacks(
|
|
||||||
[self._server_address, self._server_address_2],
|
|
||||||
"dummy_run_key", send_traceback, sess.graph)
|
|
||||||
|
|
||||||
servers = [self._server, self._server_2]
|
|
||||||
for server in servers:
|
|
||||||
# Even though the source file content is not sent, the traceback
|
|
||||||
# should have been sent.
|
|
||||||
tb = server.query_op_traceback("two/a")
|
|
||||||
self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
|
|
||||||
tb = server.query_op_traceback("two/b")
|
|
||||||
self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
|
|
||||||
tb = server.query_op_traceback("two/x")
|
|
||||||
self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)
|
|
||||||
|
|
||||||
self.assertIn(
|
|
||||||
(self._curr_file_path, send_lineno, this_func_name),
|
|
||||||
server.query_origin_stack()[-1])
|
|
||||||
|
|
||||||
tf_trace_file_path = (
|
|
||||||
self._findFirstTraceInsideTensorFlowPyLibrary(x.op))
|
|
||||||
# Verify that the source content is not sent to the server.
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
self._server.query_source_file_line(tf_trace_file_path, 0)
|
|
||||||
|
|
||||||
def testSendEagerTracebacksToSingleDebugServer(self):
|
def testSendEagerTracebacksToSingleDebugServer(self):
|
||||||
this_func_name = "testSendEagerTracebacksToSingleDebugServer"
|
this_func_name = "testSendEagerTracebacksToSingleDebugServer"
|
||||||
send_traceback = traceback.extract_stack()
|
send_traceback = traceback.extract_stack()
|
||||||
@ -213,6 +178,20 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertIn((self._curr_file_path, send_lineno, this_func_name),
|
self.assertIn((self._curr_file_path, send_lineno, this_func_name),
|
||||||
self._server.query_origin_stack()[-1])
|
self._server.query_origin_stack()[-1])
|
||||||
|
|
||||||
|
def testGRPCServerMessageSizeLimit(self):
|
||||||
|
"""Assert gRPC debug server is started with unlimited message size."""
|
||||||
|
with test.mock.patch.object(
|
||||||
|
grpc, "server", wraps=grpc.server) as mock_grpc_server:
|
||||||
|
(_, _, _, server_thread,
|
||||||
|
server) = grpc_debug_test_server.start_server_on_separate_thread(
|
||||||
|
poll_server=True)
|
||||||
|
mock_grpc_server.assert_called_with(
|
||||||
|
test.mock.ANY,
|
||||||
|
options=[("grpc.max_receive_message_length", -1),
|
||||||
|
("grpc.max_send_message_length", -1)])
|
||||||
|
server.stop_server().wait()
|
||||||
|
server_thread.join()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user