Add an enter_master_device
flag in tf.config.experimental_connect_to_cluster API.
PiperOrigin-RevId: 264925193
This commit is contained in:
parent
63198ea100
commit
2800f9b00b
@ -289,6 +289,11 @@ class TPUClusterResolver(ClusterResolver):
|
||||
else:
|
||||
# Strip numerical ports.
|
||||
self._tpu = bns_and_port[0]
|
||||
|
||||
# Remove '.brain' suffix.
|
||||
# TODO(b/139700237): Support bns address with named port.
|
||||
if self._tpu.endswith(compat.as_bytes('.brain')):
|
||||
self._tpu = self._tpu[:-6]
|
||||
else:
|
||||
self._environment = ''
|
||||
self.rpc_layer = 'grpc'
|
||||
|
@ -19,11 +19,13 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
from absl import logging
|
||||
|
||||
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import remote_utils
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import nest
|
||||
@ -77,7 +79,8 @@ def connect_to_remote_host(remote_host=None, job_name="worker"):
|
||||
def connect_to_cluster(cluster_spec_or_resolver,
|
||||
job_name="localhost",
|
||||
task_index=0,
|
||||
protocol=None):
|
||||
protocol=None,
|
||||
make_master_device_default=True):
|
||||
"""Connects to the given cluster.
|
||||
|
||||
Will make devices on the cluster available to use. Note that calling this more
|
||||
@ -94,6 +97,11 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
||||
task_index: The local task index.
|
||||
protocol: The communication protocol, such as `"grpc"`. If unspecified, will
|
||||
use the default from `python/platform/remote_utils.py`.
|
||||
make_master_device_default: If True and a cluster resolver is passed, will
|
||||
automatically enter the master task device scope, which indicates the
|
||||
master becomes the default device to run ops. It won't do anything if
|
||||
a cluster spec is passed. Will throw an error if the caller is currently
|
||||
already in some device scope.
|
||||
"""
|
||||
protocol = protocol or remote_utils.get_default_communication_protocol()
|
||||
if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
|
||||
@ -124,6 +132,44 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
||||
os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
|
||||
context.set_server_def(server_def)
|
||||
|
||||
if make_master_device_default and isinstance(
|
||||
cluster_spec_or_resolver,
|
||||
cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
|
||||
master = cluster_spec_or_resolver.master()
|
||||
master_job_name = None
|
||||
master_task_id = None
|
||||
for job_name in cluster_spec.jobs:
|
||||
for task_id in cluster_spec.task_indices(job_name):
|
||||
task_address = cluster_spec.task_address(job_name, task_id)
|
||||
if master in task_address or task_address in master:
|
||||
master_job_name = job_name
|
||||
master_task_id = task_id
|
||||
break
|
||||
|
||||
if not master_job_name:
|
||||
raise ValueError(
|
||||
"`make_master_device_default` is set to True but cannot find "
|
||||
"master %s in the cluster" % master)
|
||||
|
||||
master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
|
||||
master_task_id)
|
||||
if not _device_stack_is_empty():
|
||||
raise ValueError("`connect_to_cluster` should not be called inside "
|
||||
"an existing device scope")
|
||||
logging.info("Entering into master device scope: %s", master_device)
|
||||
# TODO(b/138389076): Think of the entering device scope behavior in the
|
||||
# failure recovery case when dealing with preemptions.
|
||||
ops.device(master_device).__enter__()
|
||||
|
||||
|
||||
def _strip_prefix(s, prefix):
|
||||
return s[len(prefix):] if s.startswith(prefix) else s
|
||||
|
||||
|
||||
def _device_stack_is_empty():
|
||||
if context.executing_eagerly():
|
||||
return not bool(context.context().device_name)
|
||||
# pylint: disable=protected-access
|
||||
device_stack = ops.get_default_graph()._device_functions_outer_to_inner
|
||||
# pylint: enable=protected-access
|
||||
return not bool(device_stack)
|
||||
|
@ -14,7 +14,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_cluster"
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_host"
|
||||
|
@ -14,7 +14,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_cluster"
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_connect_to_host"
|
||||
|
Loading…
Reference in New Issue
Block a user