Add an enter_master_device flag in tf.config.experimental_connect_to_cluster API.

PiperOrigin-RevId: 264925193
This commit is contained in:
Ruoxin Sang 2019-08-22 14:52:21 -07:00 committed by TensorFlower Gardener
parent 63198ea100
commit 2800f9b00b
4 changed files with 54 additions and 3 deletions

View File

@ -289,6 +289,11 @@ class TPUClusterResolver(ClusterResolver):
else:
# Strip numerical ports.
self._tpu = bns_and_port[0]
# Remove '.brain' suffix.
# TODO(b/139700237): Support bns address with named port.
if self._tpu.endswith(compat.as_bytes('.brain')):
self._tpu = self._tpu[:-6]
else:
self._environment = ''
self.rpc_layer = 'grpc'

View File

@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
import os
from absl import logging
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.platform import remote_utils
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
@ -77,7 +79,8 @@ def connect_to_remote_host(remote_host=None, job_name="worker"):
def connect_to_cluster(cluster_spec_or_resolver,
job_name="localhost",
task_index=0,
protocol=None):
protocol=None,
make_master_device_default=True):
"""Connects to the given cluster.
Will make devices on the cluster available to use. Note that calling this more
@ -94,6 +97,11 @@ def connect_to_cluster(cluster_spec_or_resolver,
task_index: The local task index.
protocol: The communication protocol, such as `"grpc"`. If unspecified, will
use the default from `python/platform/remote_utils.py`.
make_master_device_default: If True and a cluster resolver is passed, will
automatically enter the master task device scope, which indicates the
master becomes the default device to run ops. It won't do anything if
a cluster spec is passed. Will throw an error if the caller is currently
already in some device scope.
"""
protocol = protocol or remote_utils.get_default_communication_protocol()
if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
@ -124,6 +132,44 @@ def connect_to_cluster(cluster_spec_or_resolver,
os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
context.set_server_def(server_def)
if make_master_device_default and isinstance(
cluster_spec_or_resolver,
cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
master = cluster_spec_or_resolver.master()
master_job_name = None
master_task_id = None
for job_name in cluster_spec.jobs:
for task_id in cluster_spec.task_indices(job_name):
task_address = cluster_spec.task_address(job_name, task_id)
if master in task_address or task_address in master:
master_job_name = job_name
master_task_id = task_id
break
if not master_job_name:
raise ValueError(
"`make_master_device_default` is set to True but cannot find "
"master %s in the cluster" % master)
master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
master_task_id)
if not _device_stack_is_empty():
raise ValueError("`connect_to_cluster` should not be called inside "
"an existing device scope")
logging.info("Entering into master device scope: %s", master_device)
# TODO(b/138389076): Think of the entering device scope behavior in the
# failure recovery case when dealing with preemptions.
ops.device(master_device).__enter__()
def _strip_prefix(s, prefix):
return s[len(prefix):] if s.startswith(prefix) else s
def _device_stack_is_empty():
if context.executing_eagerly():
return not bool(context.context().device_name)
# pylint: disable=protected-access
device_stack = ops.get_default_graph()._device_functions_outer_to_inner
# pylint: enable=protected-access
return not bool(device_stack)

View File

@ -14,7 +14,7 @@ tf_module {
}
member_method {
name: "experimental_connect_to_cluster"
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
}
member_method {
name: "experimental_connect_to_host"

View File

@ -14,7 +14,7 @@ tf_module {
}
member_method {
name: "experimental_connect_to_cluster"
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\'], "
argspec: "args=[\'cluster_spec_or_resolver\', \'job_name\', \'task_index\', \'protocol\', \'make_master_device_default\'], varargs=None, keywords=None, defaults=[\'localhost\', \'0\', \'None\', \'True\'], "
}
member_method {
name: "experimental_connect_to_host"