Move default protocol parameter to a platforms library.
PiperOrigin-RevId: 260998356
This commit is contained in:
parent
620fbe292f
commit
8d52e1c646
@ -24,6 +24,7 @@ from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
|||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.platform import remote_utils
|
||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
@ -73,11 +74,10 @@ def connect_to_remote_host(remote_host=None, job_name="worker"):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("config.experimental_connect_to_cluster")
|
@tf_export("config.experimental_connect_to_cluster")
|
||||||
def connect_to_cluster(
|
def connect_to_cluster(cluster_spec_or_resolver,
|
||||||
cluster_spec_or_resolver,
|
job_name="localhost",
|
||||||
job_name="localhost",
|
task_index=0,
|
||||||
task_index=0,
|
protocol=None):
|
||||||
protocol="grpc"):
|
|
||||||
"""Connects to the given cluster.
|
"""Connects to the given cluster.
|
||||||
|
|
||||||
Will make devices on the cluster available to use. Note that calling this more
|
Will make devices on the cluster available to use. Note that calling this more
|
||||||
@ -92,8 +92,10 @@ def connect_to_cluster(
|
|||||||
the cluster.
|
the cluster.
|
||||||
job_name: The name of the local job.
|
job_name: The name of the local job.
|
||||||
task_index: The local task index.
|
task_index: The local task index.
|
||||||
protocol: The communication protocol.
|
protocol: The communication protocol, such as `"grpc"`. If unspecified, will
|
||||||
|
use the default from `python/platform/remote_utils.py`.
|
||||||
"""
|
"""
|
||||||
|
protocol = protocol or remote_utils.get_default_communication_protocol()
|
||||||
if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
|
if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
|
||||||
cluster_spec = cluster_spec_or_resolver
|
cluster_spec = cluster_spec_or_resolver
|
||||||
elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
|
elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
|
||||||
|
22
tensorflow/python/platform/remote_utils.py
Normal file
22
tensorflow/python/platform/remote_utils.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Platform-specific helpers for connecting to remote servers."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_communication_protocol():
|
||||||
|
return 'grpc'
|
@ -14,7 +14,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_connect_to_cluster"
|
name: "experimental_connect_to_cluster"
|
||||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'grpc\'], "
|
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_connect_to_host"
|
name: "experimental_connect_to_host"
|
||||||
|
@ -14,7 +14,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_connect_to_cluster"
|
name: "experimental_connect_to_cluster"
|
||||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'grpc\'], "
|
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_connect_to_host"
|
name: "experimental_connect_to_host"
|
||||||
|
Loading…
Reference in New Issue
Block a user