Add enter_master_device
flag.
This commit is contained in:
parent
69b1feac62
commit
733002ffed
@ -19,11 +19,13 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.platform import remote_utils
|
from tensorflow.python.platform import remote_utils
|
||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -36,18 +38,15 @@ _GRPC_PREFIX = "grpc://"
|
|||||||
@tf_export("config.experimental_connect_to_host")
|
@tf_export("config.experimental_connect_to_host")
|
||||||
def connect_to_remote_host(remote_host=None, job_name="worker"):
|
def connect_to_remote_host(remote_host=None, job_name="worker"):
|
||||||
"""Connects to a single machine to enable remote execution on it.
|
"""Connects to a single machine to enable remote execution on it.
|
||||||
|
|
||||||
Will make devices on the remote host available to use. Note that calling this
|
Will make devices on the remote host available to use. Note that calling this
|
||||||
more than once will work, but will invalidate any tensor handles on the old
|
more than once will work, but will invalidate any tensor handles on the old
|
||||||
remote devices.
|
remote devices.
|
||||||
|
|
||||||
Using the default job_name of worker, you can schedule ops to run remotely as
|
Using the default job_name of worker, you can schedule ops to run remotely as
|
||||||
follows:
|
follows:
|
||||||
```python
|
```python
|
||||||
# Enable eager execution, and connect to the remote host.
|
# Enable eager execution, and connect to the remote host.
|
||||||
tf.compat.v1.enable_eager_execution()
|
tf.compat.v1.enable_eager_execution()
|
||||||
tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876")
|
tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876")
|
||||||
|
|
||||||
with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
|
with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
|
||||||
# The following tensors should be resident on the remote device, and the op
|
# The following tensors should be resident on the remote device, and the op
|
||||||
# will also execute remotely.
|
# will also execute remotely.
|
||||||
@ -55,11 +54,9 @@ def connect_to_remote_host(remote_host=None, job_name="worker"):
|
|||||||
x2 = array_ops.ones([2, 2])
|
x2 = array_ops.ones([2, 2])
|
||||||
y = math_ops.matmul(x1, x2)
|
y = math_ops.matmul(x1, x2)
|
||||||
```
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
remote_host: a single or a list the remote server addr in host-port format.
|
remote_host: a single or a list the remote server addr in host-port format.
|
||||||
job_name: The job name under which the new server will be accessible.
|
job_name: The job name under which the new server will be accessible.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if remote_host is None.
|
ValueError: if remote_host is None.
|
||||||
"""
|
"""
|
||||||
@ -77,16 +74,14 @@ def connect_to_remote_host(remote_host=None, job_name="worker"):
|
|||||||
def connect_to_cluster(cluster_spec_or_resolver,
|
def connect_to_cluster(cluster_spec_or_resolver,
|
||||||
job_name="localhost",
|
job_name="localhost",
|
||||||
task_index=0,
|
task_index=0,
|
||||||
protocol=None):
|
protocol=None,
|
||||||
|
make_master_device_default=True):
|
||||||
"""Connects to the given cluster.
|
"""Connects to the given cluster.
|
||||||
|
|
||||||
Will make devices on the cluster available to use. Note that calling this more
|
Will make devices on the cluster available to use. Note that calling this more
|
||||||
than once will work, but will invalidate any tensor handles on the old remote
|
than once will work, but will invalidate any tensor handles on the old remote
|
||||||
devices.
|
devices.
|
||||||
|
|
||||||
If the given local job name is not present in the cluster specification, it
|
If the given local job name is not present in the cluster specification, it
|
||||||
will be automatically added, using an unused port on the localhost.
|
will be automatically added, using an unused port on the localhost.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
|
cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
|
||||||
the cluster.
|
the cluster.
|
||||||
@ -94,6 +89,11 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
|||||||
task_index: The local task index.
|
task_index: The local task index.
|
||||||
protocol: The communication protocol, such as `"grpc"`. If unspecified, will
|
protocol: The communication protocol, such as `"grpc"`. If unspecified, will
|
||||||
use the default from `python/platform/remote_utils.py`.
|
use the default from `python/platform/remote_utils.py`.
|
||||||
|
make_master_device_default: If True and a cluster resolver is passed, will
|
||||||
|
automatically enter the master task device scope, which indicates the
|
||||||
|
master becomes the default device to run ops. It won't do anything if
|
||||||
|
a cluster spec is passed. Will throw an error if the caller is currently
|
||||||
|
already in some device scope.
|
||||||
"""
|
"""
|
||||||
protocol = protocol or remote_utils.get_default_communication_protocol()
|
protocol = protocol or remote_utils.get_default_communication_protocol()
|
||||||
if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
|
if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
|
||||||
@ -124,6 +124,44 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
|||||||
os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
|
os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
|
||||||
context.set_server_def(server_def)
|
context.set_server_def(server_def)
|
||||||
|
|
||||||
|
if make_master_device_default and isinstance(
|
||||||
|
cluster_spec_or_resolver,
|
||||||
|
cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
|
||||||
|
master = cluster_spec_or_resolver.master()
|
||||||
|
master_job_name = None
|
||||||
|
master_task_id = None
|
||||||
|
for job_name in cluster_spec.jobs:
|
||||||
|
for task_id in cluster_spec.task_indices(job_name):
|
||||||
|
task_address = cluster_spec.task_address(job_name, task_id)
|
||||||
|
if master in task_address or task_address in master:
|
||||||
|
master_job_name = job_name
|
||||||
|
master_task_id = task_id
|
||||||
|
break
|
||||||
|
|
||||||
|
if not master_job_name:
|
||||||
|
raise ValueError(
|
||||||
|
"`make_master_device_default` is set to True but cannot find "
|
||||||
|
"master %s in the cluster" % master)
|
||||||
|
|
||||||
|
master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
|
||||||
|
master_task_id)
|
||||||
|
if not _device_stack_is_empty():
|
||||||
|
raise ValueError("`connect_to_cluster` should not be called inside "
|
||||||
|
"an existing device scope")
|
||||||
|
logging.info("Entering into master device scope: %s", master_device)
|
||||||
|
# TODO(b/138389076): Think of the entering device scope behavior in the
|
||||||
|
# failure recovery case when dealing with preemptions.
|
||||||
|
ops.device(master_device).__enter__()
|
||||||
|
|
||||||
|
|
||||||
def _strip_prefix(s, prefix):
  """Returns `s` with a leading `prefix` removed.

  Args:
    s: The string to examine.
    prefix: The prefix to drop from the front of `s`.

  Returns:
    `s` without its leading `prefix` when `s` starts with it; otherwise `s`
    unchanged.
  """
  if not s.startswith(prefix):
    return s
  return s[len(prefix):]
|
||||||
|
|
||||||
|
|
||||||
|
def _device_stack_is_empty():
  """Reports whether no device scope is currently in effect.

  When executing eagerly, the check looks at the eager context's current
  `device_name`. Otherwise it looks at the default graph's stack of device
  functions.

  Returns:
    True if the inspected device name / device-function stack is empty
    (falsy), False otherwise.
  """
  if not context.executing_eagerly():
    # Graph mode: the device scopes live on the default graph.
    # pylint: disable=protected-access
    graph_device_fns = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access
    return not graph_device_fns

  # Eager mode: an unset device name means no scope has been entered.
  eager_device_name = context.context().device_name
  return not eager_device_name
|
||||||
|
Loading…
Reference in New Issue
Block a user