From 3af03be757b63ea6fbd28cc351d5d2323c526354 Mon Sep 17 00:00:00 2001
From: Shanqing Cai <cais@google.com>
Date: Tue, 2 May 2017 18:56:32 -0800
Subject: [PATCH] tfdbg: internal-only changes Change: 154914490

---
 tensorflow/python/debug/wrappers/framework.py        |  6 ------
 tensorflow/python/debug/wrappers/framework_test.py   | 12 ------------
 tensorflow/tools/dist_test/server/BUILD              |  2 +-
 .../tools/dist_test/server/grpc_tensorflow_server.py | 12 +++++++++++-
 4 files changed, 12 insertions(+), 20 deletions(-)
 mode change 100755 => 100644 tensorflow/tools/dist_test/server/grpc_tensorflow_server.py

diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index 50645c1c874..0d8616a69fb 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -348,12 +348,6 @@ class BaseDebugWrapperSession(session.SessionInterface):
 
     _check_type(sess, session.BaseSession)
 
-    # TODO(cais): Remove this check once tfdbg is integrated with GrpcSession.
-    if sess.sess_str:
-      raise NotImplementedError(
-          "Non-DirectSession support is not available from TensorFlow "
-          "Debugger yet (sess_str=%s)" % sess.sess_str)
-
     # The session being wrapped.
     self._sess = sess
     self._thread_name_filter_pattern = (re.compile(thread_name_filter)
diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py
index 1d69c7769a2..fd0efcd925f 100644
--- a/tensorflow/python/debug/wrappers/framework_test.py
+++ b/tensorflow/python/debug/wrappers/framework_test.py
@@ -384,18 +384,6 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase):
         ["a_init", "b_init"],
         [datum.node_name for datum in dump.dumped_tensor_data])
 
-  def testUsingNonDirectSessionRaisesNotImplementedError(self):
-    # TODO(cais): Remove this test once tfdbg is integrated with GrpcSession.
-    fake_non_direct_session = session.Session()
-    fake_non_direct_session._target = "foo"
-
-    with self.assertRaisesRegexp(
-        NotImplementedError,
-        r"Non-DirectSession support is not available from TensorFlow Debugger "
-        r"yet \(sess_str=foo\)"):
-      TestDebugWrapperSession(
-          fake_non_direct_session, self._dump_root, self._observer)
-
 
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/tools/dist_test/server/BUILD b/tensorflow/tools/dist_test/server/BUILD
index 9d008ec9ce5..865af8dd7b2 100644
--- a/tensorflow/tools/dist_test/server/BUILD
+++ b/tensorflow/tools/dist_test/server/BUILD
@@ -9,7 +9,7 @@ exports_files(["LICENSE"])
 
 load("//tensorflow:tensorflow.bzl", "py_test")
 
-py_library(
+py_binary(
     name = "grpc_tensorflow_server",
     srcs = [
         "grpc_tensorflow_server.py",
diff --git a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
old mode 100755
new mode 100644
index 2d774577b6d..bd6700a0b1f
--- a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
+++ b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
@@ -36,6 +36,7 @@ from __future__ import print_function
 import argparse
 import sys
 
+from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import tensorflow_server_pb2
 from tensorflow.python.platform import app
 from tensorflow.python.training import server_lib
@@ -103,8 +104,11 @@ def main(unused_args):
     raise ValueError("Invalid task_id: %d" % FLAGS.task_id)
   server_def.task_index = FLAGS.task_id
 
+  config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
+      per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction))
+
   # Create GRPC Server instance
-  server = server_lib.Server(server_def)
+  server = server_lib.Server(server_def, config=config)
 
   # join() is blocking, unlike start()
   server.join()
@@ -137,6 +141,11 @@ if __name__ == "__main__":
       default=0,
       help="Task index, e.g., 0"
   )
+  parser.add_argument(
+      "--gpu_memory_fraction",
+      type=float,
+      default=1.0,
+      help="Fraction of GPU memory allocated",)
   parser.add_argument(
       "--verbose",
       type="bool",
@@ -145,5 +154,6 @@ if __name__ == "__main__":
       default=False,
       help="Verbose mode"
   )
+
   FLAGS, unparsed = parser.parse_known_args()
   app.run(main=main, argv=[sys.argv[0]] + unparsed)