From 3af03be757b63ea6fbd28cc351d5d2323c526354 Mon Sep 17 00:00:00 2001 From: Shanqing Cai <cais@google.com> Date: Tue, 2 May 2017 18:56:32 -0800 Subject: [PATCH] tfdbg: internal-only changes Change: 154914490 --- tensorflow/python/debug/wrappers/framework.py | 6 ------ tensorflow/python/debug/wrappers/framework_test.py | 12 ------------ tensorflow/tools/dist_test/server/BUILD | 2 +- .../tools/dist_test/server/grpc_tensorflow_server.py | 12 +++++++++++- 4 files changed, 12 insertions(+), 20 deletions(-) mode change 100755 => 100644 tensorflow/tools/dist_test/server/grpc_tensorflow_server.py diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index 50645c1c874..0d8616a69fb 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -348,12 +348,6 @@ class BaseDebugWrapperSession(session.SessionInterface): _check_type(sess, session.BaseSession) - # TODO(cais): Remove this check once tfdbg is integrated with GrpcSession. - if sess.sess_str: - raise NotImplementedError( - "Non-DirectSession support is not available from TensorFlow " - "Debugger yet (sess_str=%s)" % sess.sess_str) - # The session being wrapped. self._sess = sess self._thread_name_filter_pattern = (re.compile(thread_name_filter) diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py index 1d69c7769a2..fd0efcd925f 100644 --- a/tensorflow/python/debug/wrappers/framework_test.py +++ b/tensorflow/python/debug/wrappers/framework_test.py @@ -384,18 +384,6 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase): ["a_init", "b_init"], [datum.node_name for datum in dump.dumped_tensor_data]) - def testUsingNonDirectSessionRaisesNotImplementedError(self): - # TODO(cais): Remove this test once tfdbg is integrated with GrpcSession. - fake_non_direct_session = session.Session() - fake_non_direct_session._target = "foo" - - with self.assertRaisesRegexp( - NotImplementedError, - r"Non-DirectSession support is not available from TensorFlow Debugger " - r"yet \(sess_str=foo\)"): - TestDebugWrapperSession( - fake_non_direct_session, self._dump_root, self._observer) - if __name__ == "__main__": googletest.main() diff --git a/tensorflow/tools/dist_test/server/BUILD b/tensorflow/tools/dist_test/server/BUILD index 9d008ec9ce5..865af8dd7b2 100644 --- a/tensorflow/tools/dist_test/server/BUILD +++ b/tensorflow/tools/dist_test/server/BUILD @@ -9,7 +9,7 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") -py_library( +py_binary( name = "grpc_tensorflow_server", srcs = [ "grpc_tensorflow_server.py", diff --git a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py old mode 100755 new mode 100644 index 2d774577b6d..bd6700a0b1f --- a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py +++ b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py @@ -36,6 +36,7 @@ from __future__ import print_function import argparse import sys +from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.platform import app from tensorflow.python.training import server_lib @@ -103,8 +104,11 @@ def main(unused_args): raise ValueError("Invalid task_id: %d" % FLAGS.task_id) server_def.task_index = FLAGS.task_id + config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions( + per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)) + # Create GRPC Server instance - server = server_lib.Server(server_def) + server = server_lib.Server(server_def, config=config) # join() is blocking, unlike start() server.join() @@ -137,6 +141,11 @@ if __name__ == "__main__": default=0, help="Task index, e.g., 0" ) + parser.add_argument( + "--gpu_memory_fraction", + type=float, + default=1.0, + help="Fraction of GPU memory allocated",) parser.add_argument( "--verbose", type="bool", @@ -145,5 +154,6 @@ if __name__ == "__main__": default=False, help="Verbose mode" ) + FLAGS, unparsed = parser.parse_known_args() app.run(main=main, argv=[sys.argv[0]] + unparsed)