[tfdbg] Fix gRPC message length limit issue in source remote

Fixes https://github.com/tensorflow/tensorboard/issues/1103

PiperOrigin-RevId: 257419107
This commit is contained in:
Shanqing Cai 2019-07-10 09:25:53 -07:00 committed by TensorFlower Gardener
parent d7fbbc0023
commit 0f5d0cfc2a
3 changed files with 36 additions and 66 deletions

View File

@ -346,7 +346,10 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
if self._server_started:
raise ValueError("Server has already started running")
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
no_max_message_sizes = [("grpc.max_receive_message_length", -1),
("grpc.max_send_message_length", -1)]
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10),
options=no_max_message_sizes)
debug_service_pb2_grpc.add_EventListenerServicer_to_server(self,
self.server)
self.server.add_insecure_port("[::]:%d" % self._server_port)

View File

@ -28,7 +28,6 @@ from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_service_pb2_grpc
from tensorflow.python.debug.lib import source_utils
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
from tensorflow.python.profiler import tfprof_logger
@ -96,11 +95,6 @@ def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string):
return non_tf_files
def grpc_message_length_bytes():
  """Return the maximum gRPC message length, in bytes (4 MiB)."""
  bytes_per_mebibyte = 1024 * 1024
  return 4 * bytes_per_mebibyte
def _send_call_tracebacks(destinations,
origin_stack,
is_eager_execution=False,
@ -169,20 +163,14 @@ def _send_call_tracebacks(destinations,
debugged_source_files.append(source_files)
for destination in destinations:
channel = grpc.insecure_channel(destination)
no_max_message_sizes = [("grpc.max_receive_message_length", -1),
("grpc.max_send_message_length", -1)]
channel = grpc.insecure_channel(destination, options=no_max_message_sizes)
stub = debug_service_pb2_grpc.EventListenerStub(channel)
stub.SendTracebacks(call_traceback)
if send_source:
for path, source_files in zip(
source_file_paths, debugged_source_files):
if source_files.ByteSize() < grpc_message_length_bytes():
stub.SendSourceFiles(source_files)
else:
tf_logging.warn(
"The content of the source file at %s is not sent to "
"gRPC debug server %s, because the message size exceeds "
"gRPC message length limit (%d bytes)." % (
path, destination, grpc_message_length_bytes()))
for source_files in debugged_source_files:
stub.SendSourceFiles(source_files)
def send_graph_tracebacks(destinations,

View File

@ -21,6 +21,8 @@ from __future__ import print_function
import os
import traceback
import grpc
from tensorflow.core.debug import debug_service_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import grpc_debug_test_server
@ -129,9 +131,17 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
send_traceback = traceback.extract_stack()
send_lineno = line_number_above()
source_remote.send_graph_tracebacks(
[self._server_address, self._server_address_2],
"dummy_run_key", send_traceback, sess.graph)
with test.mock.patch.object(
grpc, "insecure_channel",
wraps=grpc.insecure_channel) as mock_grpc_channel:
source_remote.send_graph_tracebacks(
[self._server_address, self._server_address_2],
"dummy_run_key", send_traceback, sess.graph)
mock_grpc_channel.assert_called_with(
test.mock.ANY,
options=[("grpc.max_receive_message_length", -1),
("grpc.max_send_message_length", -1)])
servers = [self._server, self._server_2]
for server in servers:
@ -157,51 +167,6 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
self.assertEqual(["dummy_run_key"], server.query_call_keys())
self.assertEqual([sess.graph.version], server.query_graph_versions())
def testSourceFileSizeExceedsGrpcMessageLengthLimit(self):
  """Source files exceeding the gRPC message length limit are not sent.

  Patches `source_remote.grpc_message_length_bytes` to return a tiny limit,
  sends graph tracebacks to both debug servers, and then checks on every
  server that the tracebacks arrived while querying source file content
  raises ValueError.
  """
  this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit"

  # Patch the method to simulate a very small message length limit.
  with test.mock.patch.object(
      source_remote, "grpc_message_length_bytes", return_value=2):
    with session.Session() as sess:
      # NOTE(review): each `line_number_above()` call presumably records the
      # line of the statement directly above it, so keep every pair adjacent
      # and do not insert anything between them.
      a = variables.Variable(21.0, name="two/a")
      a_lineno = line_number_above()
      b = variables.Variable(2.0, name="two/b")
      b_lineno = line_number_above()
      x = math_ops.add(a, b, name="two/x")
      x_lineno = line_number_above()

      send_traceback = traceback.extract_stack()
      send_lineno = line_number_above()
      source_remote.send_graph_tracebacks(
          [self._server_address, self._server_address_2],
          "dummy_run_key", send_traceback, sess.graph)

      # The path does not depend on the server, so look it up once.
      tf_trace_file_path = (
          self._findFirstTraceInsideTensorFlowPyLibrary(x.op))

      servers = [self._server, self._server_2]
      for server in servers:
        # Even though the source file content is not sent, the traceback
        # should have been sent.
        tb = server.query_op_traceback("two/a")
        self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
        tb = server.query_op_traceback("two/b")
        self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
        tb = server.query_op_traceback("two/x")
        self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)

        self.assertIn(
            (self._curr_file_path, send_lineno, this_func_name),
            server.query_origin_stack()[-1])

        # Verify that the source content is not sent to this server. Use the
        # loop variable so that both servers are checked, not only the first.
        with self.assertRaises(ValueError):
          server.query_source_file_line(tf_trace_file_path, 0)
def testSendEagerTracebacksToSingleDebugServer(self):
this_func_name = "testSendEagerTracebacksToSingleDebugServer"
send_traceback = traceback.extract_stack()
@ -213,6 +178,20 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
self.assertIn((self._curr_file_path, send_lineno, this_func_name),
self._server.query_origin_stack()[-1])
def testGRPCServerMessageSizeLimit(self):
  """Assert gRPC debug server is started with unlimited message size.

  Wraps `grpc.server` in a mock, starts a debug test server on a separate
  thread, and checks that `grpc.server` was called with both the receive and
  send message length options set to -1.
  """
  with test.mock.patch.object(
      grpc, "server", wraps=grpc.server) as mock_grpc_server:
    (_, _, _, server_thread,
     server) = grpc_debug_test_server.start_server_on_separate_thread(
         poll_server=True)
    try:
      mock_grpc_server.assert_called_with(
          test.mock.ANY,
          options=[("grpc.max_receive_message_length", -1),
                   ("grpc.max_send_message_length", -1)])
    finally:
      # Always shut the server down, even when the assertion above fails, so
      # a failing test does not leave a running server and thread behind.
      server.stop_server().wait()
      server_thread.join()
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
googletest.main()