Move default protocol parameter to a platforms library.
PiperOrigin-RevId: 260998356
This commit is contained in:
parent
620fbe292f
commit
8d52e1c646
@ -24,6 +24,7 @@ from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.platform import remote_utils
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -73,11 +74,10 @@ def connect_to_remote_host(remote_host=None, job_name="worker"):
|
||||
|
||||
|
||||
@tf_export("config.experimental_connect_to_cluster")
|
||||
def connect_to_cluster(
|
||||
cluster_spec_or_resolver,
|
||||
job_name="localhost",
|
||||
task_index=0,
|
||||
protocol="grpc"):
|
||||
def connect_to_cluster(cluster_spec_or_resolver,
|
||||
job_name="localhost",
|
||||
task_index=0,
|
||||
protocol=None):
|
||||
"""Connects to the given cluster.
|
||||
|
||||
Will make devices on the cluster available to use. Note that calling this more
|
||||
@ -92,8 +92,10 @@ def connect_to_cluster(
|
||||
the cluster.
|
||||
job_name: The name of the local job.
|
||||
task_index: The local task index.
|
||||
protocol: The communication protocol.
|
||||
protocol: The communication protocol, such as `"grpc"`. If unspecified, will
|
||||
use the default from `python/platform/remote_utils.py`.
|
||||
"""
|
||||
protocol = protocol or remote_utils.get_default_communication_protocol()
|
||||
if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
|
||||
cluster_spec = cluster_spec_or_resolver
|
||||
elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
|
||||
|
22
tensorflow/python/platform/remote_utils.py
Normal file
22
tensorflow/python/platform/remote_utils.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Platform-specific helpers for connecting to remote servers."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
def get_default_communication_protocol():
|
||||
return 'grpc'
|
@ -14,7 +14,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_cluster"
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'grpc\'], "
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_host"
|
||||
|
@ -14,7 +14,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_cluster"
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'grpc\'], "
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_host"
|
||||
|
Loading…
Reference in New Issue
Block a user