[tfdbg] Fix gRPC message length limit issue in source remote
Fixes https://github.com/tensorflow/tensorboard/issues/1103 PiperOrigin-RevId: 257419107
This commit is contained in:
parent
d7fbbc0023
commit
0f5d0cfc2a
@ -346,7 +346,10 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
||||
if self._server_started:
|
||||
raise ValueError("Server has already started running")
|
||||
|
||||
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
no_max_message_sizes = [("grpc.max_receive_message_length", -1),
|
||||
("grpc.max_send_message_length", -1)]
|
||||
self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10),
|
||||
options=no_max_message_sizes)
|
||||
debug_service_pb2_grpc.add_EventListenerServicer_to_server(self,
|
||||
self.server)
|
||||
self.server.add_insecure_port("[::]:%d" % self._server_port)
|
||||
|
@ -28,7 +28,6 @@ from tensorflow.python.debug.lib import common
|
||||
from tensorflow.python.debug.lib import debug_service_pb2_grpc
|
||||
from tensorflow.python.debug.lib import source_utils
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.profiler import tfprof_logger
|
||||
|
||||
|
||||
@ -96,11 +95,6 @@ def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string):
|
||||
return non_tf_files
|
||||
|
||||
|
||||
def grpc_message_length_bytes():
|
||||
"""Maximum gRPC message length in bytes."""
|
||||
return 4 * 1024 * 1024
|
||||
|
||||
|
||||
def _send_call_tracebacks(destinations,
|
||||
origin_stack,
|
||||
is_eager_execution=False,
|
||||
@ -169,20 +163,14 @@ def _send_call_tracebacks(destinations,
|
||||
debugged_source_files.append(source_files)
|
||||
|
||||
for destination in destinations:
|
||||
channel = grpc.insecure_channel(destination)
|
||||
no_max_message_sizes = [("grpc.max_receive_message_length", -1),
|
||||
("grpc.max_send_message_length", -1)]
|
||||
channel = grpc.insecure_channel(destination, options=no_max_message_sizes)
|
||||
stub = debug_service_pb2_grpc.EventListenerStub(channel)
|
||||
stub.SendTracebacks(call_traceback)
|
||||
if send_source:
|
||||
for path, source_files in zip(
|
||||
source_file_paths, debugged_source_files):
|
||||
if source_files.ByteSize() < grpc_message_length_bytes():
|
||||
stub.SendSourceFiles(source_files)
|
||||
else:
|
||||
tf_logging.warn(
|
||||
"The content of the source file at %s is not sent to "
|
||||
"gRPC debug server %s, because the message size exceeds "
|
||||
"gRPC message length limit (%d bytes)." % (
|
||||
path, destination, grpc_message_length_bytes()))
|
||||
for source_files in debugged_source_files:
|
||||
stub.SendSourceFiles(source_files)
|
||||
|
||||
|
||||
def send_graph_tracebacks(destinations,
|
||||
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import grpc
|
||||
|
||||
from tensorflow.core.debug import debug_service_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.debug.lib import grpc_debug_test_server
|
||||
@ -129,9 +131,17 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
|
||||
|
||||
send_traceback = traceback.extract_stack()
|
||||
send_lineno = line_number_above()
|
||||
source_remote.send_graph_tracebacks(
|
||||
[self._server_address, self._server_address_2],
|
||||
"dummy_run_key", send_traceback, sess.graph)
|
||||
|
||||
with test.mock.patch.object(
|
||||
grpc, "insecure_channel",
|
||||
wraps=grpc.insecure_channel) as mock_grpc_channel:
|
||||
source_remote.send_graph_tracebacks(
|
||||
[self._server_address, self._server_address_2],
|
||||
"dummy_run_key", send_traceback, sess.graph)
|
||||
mock_grpc_channel.assert_called_with(
|
||||
test.mock.ANY,
|
||||
options=[("grpc.max_receive_message_length", -1),
|
||||
("grpc.max_send_message_length", -1)])
|
||||
|
||||
servers = [self._server, self._server_2]
|
||||
for server in servers:
|
||||
@ -157,51 +167,6 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(["dummy_run_key"], server.query_call_keys())
|
||||
self.assertEqual([sess.graph.version], server.query_graph_versions())
|
||||
|
||||
def testSourceFileSizeExceedsGrpcMessageLengthLimit(self):
|
||||
"""In case source file size exceeds the grpc message length limit.
|
||||
|
||||
it ought not to have been sent to the server.
|
||||
"""
|
||||
this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit"
|
||||
|
||||
# Patch the method to simulate a very small message length limit.
|
||||
with test.mock.patch.object(
|
||||
source_remote, "grpc_message_length_bytes", return_value=2):
|
||||
with session.Session() as sess:
|
||||
a = variables.Variable(21.0, name="two/a")
|
||||
a_lineno = line_number_above()
|
||||
b = variables.Variable(2.0, name="two/b")
|
||||
b_lineno = line_number_above()
|
||||
x = math_ops.add(a, b, name="two/x")
|
||||
x_lineno = line_number_above()
|
||||
|
||||
send_traceback = traceback.extract_stack()
|
||||
send_lineno = line_number_above()
|
||||
source_remote.send_graph_tracebacks(
|
||||
[self._server_address, self._server_address_2],
|
||||
"dummy_run_key", send_traceback, sess.graph)
|
||||
|
||||
servers = [self._server, self._server_2]
|
||||
for server in servers:
|
||||
# Even though the source file content is not sent, the traceback
|
||||
# should have been sent.
|
||||
tb = server.query_op_traceback("two/a")
|
||||
self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
|
||||
tb = server.query_op_traceback("two/b")
|
||||
self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
|
||||
tb = server.query_op_traceback("two/x")
|
||||
self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)
|
||||
|
||||
self.assertIn(
|
||||
(self._curr_file_path, send_lineno, this_func_name),
|
||||
server.query_origin_stack()[-1])
|
||||
|
||||
tf_trace_file_path = (
|
||||
self._findFirstTraceInsideTensorFlowPyLibrary(x.op))
|
||||
# Verify that the source content is not sent to the server.
|
||||
with self.assertRaises(ValueError):
|
||||
self._server.query_source_file_line(tf_trace_file_path, 0)
|
||||
|
||||
def testSendEagerTracebacksToSingleDebugServer(self):
|
||||
this_func_name = "testSendEagerTracebacksToSingleDebugServer"
|
||||
send_traceback = traceback.extract_stack()
|
||||
@ -213,6 +178,20 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
|
||||
self.assertIn((self._curr_file_path, send_lineno, this_func_name),
|
||||
self._server.query_origin_stack()[-1])
|
||||
|
||||
def testGRPCServerMessageSizeLimit(self):
|
||||
"""Assert gRPC debug server is started with unlimited message size."""
|
||||
with test.mock.patch.object(
|
||||
grpc, "server", wraps=grpc.server) as mock_grpc_server:
|
||||
(_, _, _, server_thread,
|
||||
server) = grpc_debug_test_server.start_server_on_separate_thread(
|
||||
poll_server=True)
|
||||
mock_grpc_server.assert_called_with(
|
||||
test.mock.ANY,
|
||||
options=[("grpc.max_receive_message_length", -1),
|
||||
("grpc.max_send_message_length", -1)])
|
||||
server.stop_server().wait()
|
||||
server_thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user