Add enter_master_device
flag.
This commit is contained in:
parent
69b1feac62
commit
733002ffed
@ -19,11 +19,13 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
from absl import logging
|
||||
|
||||
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import remote_utils
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import nest
|
||||
@ -36,18 +38,15 @@ _GRPC_PREFIX = "grpc://"
|
||||
@tf_export("config.experimental_connect_to_host")
|
||||
def connect_to_remote_host(remote_host=None, job_name="worker"):
|
||||
"""Connects to a single machine to enable remote execution on it.
|
||||
|
||||
Will make devices on the remote host available to use. Note that calling this
|
||||
more than once will work, but will invalidate any tensor handles on the old
|
||||
remote devices.
|
||||
|
||||
Using the default job_name of worker, you can schedule ops to run remotely as
|
||||
follows:
|
||||
```python
|
||||
# Enable eager execution, and connect to the remote host.
|
||||
tf.compat.v1.enable_eager_execution()
|
||||
tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876")
|
||||
|
||||
with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
|
||||
# The following tensors should be resident on the remote device, and the op
|
||||
# will also execute remotely.
|
||||
@ -55,11 +54,9 @@ def connect_to_remote_host(remote_host=None, job_name="worker"):
|
||||
x2 = array_ops.ones([2, 2])
|
||||
y = math_ops.matmul(x1, x2)
|
||||
```
|
||||
|
||||
Args:
|
||||
remote_host: a single or a list the remote server addr in host-port format.
|
||||
job_name: The job name under which the new server will be accessible.
|
||||
|
||||
Raises:
|
||||
ValueError: if remote_host is None.
|
||||
"""
|
||||
@ -77,16 +74,14 @@ def connect_to_remote_host(remote_host=None, job_name="worker"):
|
||||
def connect_to_cluster(cluster_spec_or_resolver,
|
||||
job_name="localhost",
|
||||
task_index=0,
|
||||
protocol=None):
|
||||
protocol=None,
|
||||
make_master_device_default=True):
|
||||
"""Connects to the given cluster.
|
||||
|
||||
Will make devices on the cluster available to use. Note that calling this more
|
||||
than once will work, but will invalidate any tensor handles on the old remote
|
||||
devices.
|
||||
|
||||
If the given local job name is not present in the cluster specification, it
|
||||
will be automatically added, using an unused port on the localhost.
|
||||
|
||||
Args:
|
||||
cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
|
||||
the cluster.
|
||||
@ -94,6 +89,11 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
||||
task_index: The local task index.
|
||||
protocol: The communication protocol, such as `"grpc"`. If unspecified, will
|
||||
use the default from `python/platform/remote_utils.py`.
|
||||
make_master_device_default: If True and a cluster resolver is passed, will
|
||||
automatically enter the master task device scope, which indicates the
|
||||
master becomes the default device to run ops. It won't do anything if
|
||||
a cluster spec is passed. Will throw an error if the caller is currently
|
||||
already in some device scope.
|
||||
"""
|
||||
protocol = protocol or remote_utils.get_default_communication_protocol()
|
||||
if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
|
||||
@ -124,6 +124,44 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
||||
os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
|
||||
context.set_server_def(server_def)
|
||||
|
||||
if make_master_device_default and isinstance(
|
||||
cluster_spec_or_resolver,
|
||||
cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
|
||||
master = cluster_spec_or_resolver.master()
|
||||
master_job_name = None
|
||||
master_task_id = None
|
||||
for job_name in cluster_spec.jobs:
|
||||
for task_id in cluster_spec.task_indices(job_name):
|
||||
task_address = cluster_spec.task_address(job_name, task_id)
|
||||
if master in task_address or task_address in master:
|
||||
master_job_name = job_name
|
||||
master_task_id = task_id
|
||||
break
|
||||
|
||||
if not master_job_name:
|
||||
raise ValueError(
|
||||
"`make_master_device_default` is set to True but cannot find "
|
||||
"master %s in the cluster" % master)
|
||||
|
||||
master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
|
||||
master_task_id)
|
||||
if not _device_stack_is_empty():
|
||||
raise ValueError("`connect_to_cluster` should not be called inside "
|
||||
"an existing device scope")
|
||||
logging.info("Entering into master device scope: %s", master_device)
|
||||
# TODO(b/138389076): Think of the entering device scope behavior in the
|
||||
# failure recovery case when dealing with preemptions.
|
||||
ops.device(master_device).__enter__()
|
||||
|
||||
|
||||
def _strip_prefix(s, prefix):
|
||||
return s[len(prefix):] if s.startswith(prefix) else s
|
||||
|
||||
|
||||
def _device_stack_is_empty():
|
||||
if context.executing_eagerly():
|
||||
return not bool(context.context().device_name)
|
||||
# pylint: disable=protected-access
|
||||
device_stack = ops.get_default_graph()._device_functions_outer_to_inner
|
||||
# pylint: enable=protected-access
|
||||
return not bool(device_stack)
|
||||
|
Loading…
Reference in New Issue
Block a user