Move default protocol parameter to a platforms library.

PiperOrigin-RevId: 260998356
This commit is contained in:
Revan Sopher 2019-07-31 14:26:29 -07:00 committed by TensorFlower Gardener
parent 620fbe292f
commit 8d52e1c646
4 changed files with 32 additions and 8 deletions

View File

@ -24,6 +24,7 @@ from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.platform import remote_utils
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@ -73,11 +74,10 @@ def connect_to_remote_host(remote_host=None, job_name="worker"):
@tf_export("config.experimental_connect_to_cluster")
def connect_to_cluster(
cluster_spec_or_resolver,
job_name="localhost",
task_index=0,
protocol="grpc"):
def connect_to_cluster(cluster_spec_or_resolver,
job_name="localhost",
task_index=0,
protocol=None):
"""Connects to the given cluster.
Will make devices on the cluster available to use. Note that calling this more
@ -92,8 +92,10 @@ def connect_to_cluster(
the cluster.
job_name: The name of the local job.
task_index: The local task index.
protocol: The communication protocol.
protocol: The communication protocol, such as `"grpc"`. If unspecified, will
use the default from `python/platform/remote_utils.py`.
"""
protocol = protocol or remote_utils.get_default_communication_protocol()
if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
cluster_spec = cluster_spec_or_resolver
elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):

View File

@ -0,0 +1,22 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Platform-specific helpers for connecting to remote servers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def get_default_communication_protocol():
return 'grpc'

View File

@ -14,7 +14,7 @@ tf_module {
}
member_method {
name: "experimental_connect_to_cluster"
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'grpc\'], "
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
}
member_method {
name: "experimental_connect_to_host"

View File

@ -14,7 +14,7 @@ tf_module {
}
member_method {
name: "experimental_connect_to_cluster"
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'grpc\'], "
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
}
member_method {
name: "experimental_connect_to_host"