From 8d52e1c6460d7e0809afd1883b6f2eb21403b1f5 Mon Sep 17 00:00:00 2001 From: Revan Sopher Date: Wed, 31 Jul 2019 14:26:29 -0700 Subject: [PATCH] Move default protocol parameter to a platforms library. PiperOrigin-RevId: 260998356 --- tensorflow/python/eager/remote.py | 14 +++++++----- tensorflow/python/platform/remote_utils.py | 22 +++++++++++++++++++ .../api/golden/v1/tensorflow.config.pbtxt | 2 +- .../api/golden/v2/tensorflow.config.pbtxt | 2 +- 4 files changed, 32 insertions(+), 8 deletions(-) create mode 100644 tensorflow/python/platform/remote_utils.py diff --git a/tensorflow/python/eager/remote.py b/tensorflow/python/eager/remote.py index cccec010e08..15dec68aec3 100644 --- a/tensorflow/python/eager/remote.py +++ b/tensorflow/python/eager/remote.py @@ -24,6 +24,7 @@ from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef from tensorflow.python import pywrap_tensorflow from tensorflow.python.distribute.cluster_resolver import cluster_resolver from tensorflow.python.eager import context +from tensorflow.python.platform import remote_utils from tensorflow.python.training import server_lib from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -73,11 +74,10 @@ def connect_to_remote_host(remote_host=None, job_name="worker"): @tf_export("config.experimental_connect_to_cluster") -def connect_to_cluster( - cluster_spec_or_resolver, - job_name="localhost", - task_index=0, - protocol="grpc"): +def connect_to_cluster(cluster_spec_or_resolver, + job_name="localhost", + task_index=0, + protocol=None): """Connects to the given cluster. Will make devices on the cluster available to use. Note that calling this more @@ -92,8 +92,10 @@ def connect_to_cluster( the cluster. job_name: The name of the local job. task_index: The local task index. - protocol: The communication protocol. + protocol: The communication protocol, such as `"grpc"`. If unspecified, will + use the default from `python/platform/remote_utils.py`. """ + protocol = protocol or remote_utils.get_default_communication_protocol() if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec): cluster_spec = cluster_spec_or_resolver elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver): diff --git a/tensorflow/python/platform/remote_utils.py b/tensorflow/python/platform/remote_utils.py new file mode 100644 index 00000000000..9ec2e5e5ef8 --- /dev/null +++ b/tensorflow/python/platform/remote_utils.py @@ -0,0 +1,22 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Platform-specific helpers for connecting to remote servers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def get_default_communication_protocol(): + return 'grpc' diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt index 0c29d7a0594..cc188a1e952 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.config.pbtxt @@ -14,7 +14,7 @@ tf_module { } member_method { name: "experimental_connect_to_cluster" - argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'grpc\'], " + argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], " } member_method { name: "experimental_connect_to_host" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt index 0c29d7a0594..cc188a1e952 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.config.pbtxt @@ -14,7 +14,7 @@ tf_module { } member_method { name: "experimental_connect_to_cluster" - argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'grpc\'], " + argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], " } member_method { name: "experimental_connect_to_host"