Add an enter_master_device
flag in tf.config.experimental_connect_to_cluster API.
PiperOrigin-RevId: 264925193
This commit is contained in:
parent
63198ea100
commit
2800f9b00b
@ -289,6 +289,11 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
else:
|
else:
|
||||||
# Strip numerical ports.
|
# Strip numerical ports.
|
||||||
self._tpu = bns_and_port[0]
|
self._tpu = bns_and_port[0]
|
||||||
|
|
||||||
|
# Remove '.brain' suffix.
|
||||||
|
# TODO(b/139700237): Support bns address with named port.
|
||||||
|
if self._tpu.endswith(compat.as_bytes('.brain')):
|
||||||
|
self._tpu = self._tpu[:-6]
|
||||||
else:
|
else:
|
||||||
self._environment = ''
|
self._environment = ''
|
||||||
self.rpc_layer = 'grpc'
|
self.rpc_layer = 'grpc'
|
||||||
|
@ -19,11 +19,13 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.platform import remote_utils
|
from tensorflow.python.platform import remote_utils
|
||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -77,7 +79,8 @@ def connect_to_remote_host(remote_host=None, job_name="worker"):
|
|||||||
def connect_to_cluster(cluster_spec_or_resolver,
|
def connect_to_cluster(cluster_spec_or_resolver,
|
||||||
job_name="localhost",
|
job_name="localhost",
|
||||||
task_index=0,
|
task_index=0,
|
||||||
protocol=None):
|
protocol=None,
|
||||||
|
make_master_device_default=True):
|
||||||
"""Connects to the given cluster.
|
"""Connects to the given cluster.
|
||||||
|
|
||||||
Will make devices on the cluster available to use. Note that calling this more
|
Will make devices on the cluster available to use. Note that calling this more
|
||||||
@ -94,6 +97,11 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
|||||||
task_index: The local task index.
|
task_index: The local task index.
|
||||||
protocol: The communication protocol, such as `"grpc"`. If unspecified, will
|
protocol: The communication protocol, such as `"grpc"`. If unspecified, will
|
||||||
use the default from `python/platform/remote_utils.py`.
|
use the default from `python/platform/remote_utils.py`.
|
||||||
|
make_master_device_default: If True and a cluster resolver is passed, will
|
||||||
|
automatically enter the master task device scope, which indicates the
|
||||||
|
master becomes the default device to run ops. It won't do anything if
|
||||||
|
a cluster spec is passed. Will throw an error if the caller is currently
|
||||||
|
already in some device scope.
|
||||||
"""
|
"""
|
||||||
protocol = protocol or remote_utils.get_default_communication_protocol()
|
protocol = protocol or remote_utils.get_default_communication_protocol()
|
||||||
if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
|
if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
|
||||||
@ -124,6 +132,44 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
|||||||
os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
|
os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
|
||||||
context.set_server_def(server_def)
|
context.set_server_def(server_def)
|
||||||
|
|
||||||
|
if make_master_device_default and isinstance(
|
||||||
|
cluster_spec_or_resolver,
|
||||||
|
cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
|
||||||
|
master = cluster_spec_or_resolver.master()
|
||||||
|
master_job_name = None
|
||||||
|
master_task_id = None
|
||||||
|
for job_name in cluster_spec.jobs:
|
||||||
|
for task_id in cluster_spec.task_indices(job_name):
|
||||||
|
task_address = cluster_spec.task_address(job_name, task_id)
|
||||||
|
if master in task_address or task_address in master:
|
||||||
|
master_job_name = job_name
|
||||||
|
master_task_id = task_id
|
||||||
|
break
|
||||||
|
|
||||||
|
if not master_job_name:
|
||||||
|
raise ValueError(
|
||||||
|
"`make_master_device_default` is set to True but cannot find "
|
||||||
|
"master %s in the cluster" % master)
|
||||||
|
|
||||||
|
master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
|
||||||
|
master_task_id)
|
||||||
|
if not _device_stack_is_empty():
|
||||||
|
raise ValueError("`connect_to_cluster` should not be called inside "
|
||||||
|
"an existing device scope")
|
||||||
|
logging.info("Entering into master device scope: %s", master_device)
|
||||||
|
# TODO(b/138389076): Think of the entering device scope behavior in the
|
||||||
|
# failure recovery case when dealing with preemptions.
|
||||||
|
ops.device(master_device).__enter__()
|
||||||
|
|
||||||
|
|
||||||
def _strip_prefix(s, prefix):
|
def _strip_prefix(s, prefix):
|
||||||
return s[len(prefix):] if s.startswith(prefix) else s
|
return s[len(prefix):] if s.startswith(prefix) else s
|
||||||
|
|
||||||
|
|
||||||
|
def _device_stack_is_empty():
|
||||||
|
if context.executing_eagerly():
|
||||||
|
return not bool(context.context().device_name)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
device_stack = ops.get_default_graph()._device_functions_outer_to_inner
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
return not bool(device_stack)
|
||||||
|
@ -14,7 +14,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_connect_to_cluster"
|
name: "experimental_connect_to_cluster"
|
||||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
|
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_connect_to_host"
|
name: "experimental_connect_to_host"
|
||||||
|
@ -14,7 +14,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_connect_to_cluster"
|
name: "experimental_connect_to_cluster"
|
||||||
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
|
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_connect_to_host"
|
name: "experimental_connect_to_host"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user