Moves ClusterResolvers into tensorflow.python.distribute in preparation for TensorFlow 2.0

PiperOrigin-RevId: 223401165
Frank Chen 2018-11-29 13:30:32 -08:00 committed by TensorFlower Gardener
parent fd7b50ee62
commit a26f3b0598
29 changed files with 2014 additions and 1618 deletions


@@ -21,85 +21,18 @@ py_library(
py_library(
name = "cluster_resolver_py",
srcs = [
srcs = glob([
"__init__.py",
"python/training/__init__.py",
],
"python/training/*.py",
]),
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":base_cluster_resolver_py",
":gce_cluster_resolver_py",
":kubernetes_cluster_resolver_py",
":slurm_cluster_resolver_py",
":tfconfig_cluster_resolver_py",
":tpu_cluster_resolver_py",
"//tensorflow/python:util",
],
)
py_library(
name = "base_cluster_resolver_py",
srcs = ["python/training/cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:training",
],
)
py_library(
name = "gce_cluster_resolver_py",
srcs = ["python/training/gce_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training",
],
)
py_library(
name = "tfconfig_cluster_resolver_py",
srcs = ["python/training/tfconfig_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training",
],
)
py_library(
name = "tpu_cluster_resolver_py",
srcs = ["python/training/tpu_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training",
],
)
py_library(
name = "slurm_cluster_resolver_py",
srcs = ["python/training/slurm_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training",
],
)
py_library(
name = "kubernetes_cluster_resolver_py",
srcs = ["python/training/kubernetes_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training",
],
deps = ["//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib"],
)
tf_py_test(
name = "base_cluster_resolver_py_test",
srcs = ["python/training/cluster_resolver_test.py"],
name = "cluster_resolver_initialization_test",
srcs = ["cluster_resolver_initialization_test.py"],
additional_deps = [
":cluster_resolver_py",
"//tensorflow/python:client_testlib",
@@ -108,86 +41,5 @@ tf_py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
main = "python/training/cluster_resolver_test.py",
)
tf_py_test(
name = "gce_cluster_resolver_py_test",
size = "small",
srcs = ["python/training/gce_cluster_resolver_test.py"],
additional_deps = [
":cluster_resolver_py",
":gce_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
main = "python/training/gce_cluster_resolver_test.py",
)
tf_py_test(
name = "tfconfig_cluster_resolver_py_test",
size = "small",
srcs = ["python/training/tfconfig_cluster_resolver_test.py"],
additional_deps = [
":tfconfig_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
grpc_enabled = True,
main = "python/training/tfconfig_cluster_resolver_test.py",
)
tf_py_test(
name = "tpu_cluster_resolver_py_test",
size = "small",
srcs = ["python/training/tpu_cluster_resolver_test.py"],
additional_deps = [
":tpu_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
grpc_enabled = True,
main = "python/training/tpu_cluster_resolver_test.py",
)
tf_py_test(
name = "slurm_cluster_resolver_py_test",
size = "small",
srcs = ["python/training/slurm_cluster_resolver_test.py"],
additional_deps = [
":cluster_resolver_py",
":slurm_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
main = "python/training/slurm_cluster_resolver_test.py",
tags = [],
)
tf_py_test(
name = "kubernetes_cluster_resolver_py_test",
size = "small",
srcs = ["python/training/kubernetes_cluster_resolver_test.py"],
additional_deps = [
":cluster_resolver_py",
":kubernetes_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
main = "python/training/kubernetes_cluster_resolver_test.py",
main = "cluster_resolver_initialization_test.py",
)


@@ -20,14 +20,14 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.kubernetes_cluster_resolver import KubernetesClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver
from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver
from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver
from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.util.all_util import remove_undocumented


@@ -0,0 +1,53 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests to ensure ClusterResolvers are usable via the old contrib path."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.cluster_resolver import SimpleClusterResolver
from tensorflow.contrib.cluster_resolver.python.training import cluster_resolver
from tensorflow.contrib.cluster_resolver.python.training import UnionClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
class ClusterResolverInitializationTest(test.TestCase):
def testCreateSimpleClusterResolverFromLib(self):
base_cluster_spec = server_lib.ClusterSpec({
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
cluster_resolver.SimpleClusterResolver(base_cluster_spec)
def testCreateSimpleClusterResolver(self):
base_cluster_spec = server_lib.ClusterSpec({
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
SimpleClusterResolver(base_cluster_spec)
def testCreateUnionClusterResolver(self):
base_cluster_spec = server_lib.ClusterSpec({
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
simple_cr = SimpleClusterResolver(base_cluster_spec)
UnionClusterResolver(simple_cr)
if __name__ == "__main__":
test.main()


@@ -18,11 +18,36 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.kubernetes_cluster_resolver import KubernetesClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
# This file (and all files in this directory in general) is a backwards
# compatibility shim that exists to re-export ClusterResolvers such that
# existing OSS code will not be broken.
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver
from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver
from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver
from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'cluster_resolver',
'gce_cluster_resolver',
'kubernetes_cluster_resolver',
'slurm_cluster_resolver',
'tfconfig_cluster_resolver',
'tpu_cluster_resolver',
'ClusterResolver',
'SimpleClusterResolver',
'UnionClusterResolver',
'GceClusterResolver',
'KubernetesClusterResolver',
'TFConfigClusterResolver',
'TPUClusterResolver',
'SlurmClusterResolver',
]
remove_undocumented(__name__, _allowed_symbols)
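After this change the contrib package is only a re-export shim. A quick editor's check (not part of this commit's diff; assumes a TensorFlow build that contains the commit) shows the old and new import paths resolving to the same class:

from tensorflow.contrib.cluster_resolver import SimpleClusterResolver as contrib_simple
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import (
    SimpleClusterResolver as core_simple)

# The contrib module simply re-imports the moved class, so both names are
# expected to refer to the same object.
assert contrib_simple is core_simple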


@@ -1,4 +1,4 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,363 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution."""
"""Stub file for ClusterResolver to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
# This file (and all files in this directory in general) is a backwards
# compatibility shim that exists to re-export ClusterResolvers such that
# existing OSS code will not be broken.
import six
# pylint: disable=unused-import
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
# pylint: enable=unused-import
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'ClusterResolver',
'SimpleClusterResolver',
'UnionClusterResolver',
]
def format_master_url(master, rpc_layer=None):
if rpc_layer:
return '%s://%s' % (rpc_layer, master)
else:
return master
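For illustration, the helper simply prefixes the RPC layer when one is given (editor's sketch, hypothetical host values):

print(format_master_url('worker0:2222', 'grpc'))  # -> grpc://worker0:2222
print(format_master_url('worker0:2222'))          # -> worker0:2222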
remove_undocumented(__name__, _allowed_symbols)
@six.add_metaclass(abc.ABCMeta)
class ClusterResolver(object):
"""Abstract class for all implementations of ClusterResolvers.
This defines the skeleton for all implementations of ClusterResolvers.
ClusterResolvers are a way for TensorFlow to communicate with various cluster
management systems (e.g. GCE, AWS, etc...).
By letting TensorFlow communicate with these systems, we will be able to
automatically discover and resolve IP addresses for various TensorFlow
workers. This will eventually allow us to automatically recover from
underlying machine failures and scale TensorFlow worker clusters up and down.
Note to Implementors: In addition to these abstract methods, you must also
implement the task_type, task_index, and rpc_layer attributes. You may choose
to implement them either as properties with getters or setters or directly
set the attributes.
- task_type is the name of the server's current named job (e.g. 'worker',
'ps' in a distributed parameterized training job).
- task_index is the ordinal index of the server within the task type.
- rpc_layer is the protocol used by TensorFlow to communicate with other
TensorFlow servers in a distributed environment.
"""
@abc.abstractmethod
def cluster_spec(self):
"""Retrieve the current state of the cluster and returns a ClusterSpec.
Returns:
A ClusterSpec representing the state of the cluster at the moment this
function is called.
Implementors of this function must take care in ensuring that the
ClusterSpec returned is up-to-date at the time of calling this function.
This usually means retrieving the information from the underlying cluster
management system every time this function is invoked and reconstructing
a cluster_spec, rather than attempting to cache anything.
"""
raise NotImplementedError()
@abc.abstractmethod
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Retrieves the name or URL of the session master.
Args:
task_type: (Optional) The type of the TensorFlow task of the master.
task_index: (Optional) The index of the TensorFlow task of the master.
rpc_layer: (Optional) The RPC protocol for the given cluster.
Returns:
The name or URL of the session master.
Implementors of this function must take care in ensuring that the master
returned is up-to-date at the time of calling this function. This usually
means retrieving the master every time this function is invoked.
"""
raise NotImplementedError()
@abc.abstractmethod
def num_accelerators_per_worker(self, session_config=None):
"""Returns the number of accelerator cores per worker.
This returns the number of accelerator cores (such as GPUs and TPUs)
available per worker. If a worker has only CPU cores available, then this
should return 0. This method will query the master for this information
if it is not otherwise known.
Args:
session_config: (Optional) Configuration for starting a new session to
query how many accelerator cores it has.
"""
raise NotImplementedError()
@abc.abstractproperty
def environment(self):
"""Returns the current environment which TensorFlow is running in."""
raise NotImplementedError()
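To make the contract above concrete, here is a hedged editor's sketch of a custom resolver that satisfies the abstract interface and the attribute requirements from the note to implementors; the class name and behavior are illustrative, not part of this commit, and it reuses ClusterSpec and format_master_url from this module:

class StaticClusterResolver(ClusterResolver):
  """Hypothetical resolver for a fixed list of worker hosts."""

  def __init__(self, hosts, rpc_layer='grpc'):
    self.task_type = 'worker'
    self.task_index = 0
    self.rpc_layer = rpc_layer
    self._hosts = list(hosts)

  def cluster_spec(self):
    # Rebuilt on every call, as the docstring above recommends.
    return ClusterSpec({'worker': self._hosts})

  def master(self, task_type=None, task_index=None, rpc_layer=None):
    return format_master_url(self._hosts[task_index or 0],
                             rpc_layer or self.rpc_layer)

  def num_accelerators_per_worker(self, session_config=None):
    return 0

  @property
  def environment(self):
    return ''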
class SimpleClusterResolver(ClusterResolver):
"""Simple implementation of ClusterResolver that accepts a ClusterSpec."""
def __init__(self, cluster_spec, master='', task_type=None, task_index=None,
environment='', num_accelerators_per_worker=0,
rpc_layer=None):
"""Creates a SimpleClusterResolver from a ClusterSpec."""
super(SimpleClusterResolver, self).__init__()
self._task_type = task_type
self._task_index = task_index
self._environment = environment
self._num_accelerators_per_worker = num_accelerators_per_worker
self._rpc_layer = rpc_layer
if not isinstance(cluster_spec, ClusterSpec):
raise TypeError('cluster_spec must be a ClusterSpec.')
self._cluster_spec = cluster_spec
if not isinstance(master, str):
raise TypeError('master must be a string.')
self._master = master
def cluster_spec(self):
"""Returns the ClusterSpec passed into the constructor."""
return self._cluster_spec
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Returns the master address to use when creating a session.
Args:
task_type: (Optional) The type of the TensorFlow task of the master.
task_index: (Optional) The index of the TensorFlow task of the master.
rpc_layer: (Optional) The RPC used by distributed TensorFlow.
Returns:
The name or URL of the session master.
If a task_type and task_index is given, this will override the `master`
string passed into the initialization function.
"""
if task_type is not None and task_index is not None:
master = self.cluster_spec().task_address(task_type, task_index)
else:
master = self._master
return format_master_url(master, rpc_layer=rpc_layer or self._rpc_layer)
@property
def task_type(self):
return self._task_type
@property
def task_index(self):
return self._task_index
@task_type.setter
def task_type(self, task_type):
self._task_type = task_type
@task_index.setter
def task_index(self, task_index):
self._task_index = task_index
@property
def environment(self):
return self._environment
def num_accelerators_per_worker(self, session_config=None):
"""Returns the number of accelerator cores per worker.
Args:
session_config: Unused. The SimpleClusterResolver does not do automatic
detection of accelerators, so a TensorFlow session will never be
created, and thus a `session_config` is never necessary here, and will
be ignored.
"""
del session_config
return self._num_accelerators_per_worker
@property
def rpc_layer(self):
return self._rpc_layer
@rpc_layer.setter
def rpc_layer(self, rpc_layer):
self._rpc_layer = rpc_layer
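A short usage sketch for SimpleClusterResolver (editor's illustration; hostnames are made up):

spec = ClusterSpec({'worker': ['worker0:2222', 'worker1:2222']})
resolver = SimpleClusterResolver(spec, master='worker0:2222',
                                 task_type='worker', task_index=0,
                                 rpc_layer='grpc')
print(resolver.master())             # -> grpc://worker0:2222
print(resolver.master('worker', 1))  # -> grpc://worker1:2222 (overrides master)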
class UnionClusterResolver(ClusterResolver):
"""Performs a union on underlying ClusterResolvers.
This class performs a union given two or more existing ClusterResolvers. It
merges the underlying ClusterResolvers, and returns one unified ClusterSpec
when cluster_spec is called. The details of the merge function is
documented in the cluster_spec function.
For additional Cluster Resolver properties such as task type, task index,
rpc layer, environment, etc..., we will return the value from the first
ClusterResolver in the union.
"""
def __init__(self, *args, **kwargs):
"""Initializes a UnionClusterResolver with other ClusterResolvers.
Args:
*args: `ClusterResolver` objects to be unionized.
**kwargs:
rpc_layer - (Optional) Override value for the RPC layer used by
TensorFlow.
task_type - (Optional) Override value for the current task type.
task_index - (Optional) Override value for the current task index.
Raises:
TypeError: If any argument is not an instance of `ClusterResolver`.
ValueError: If there are no arguments passed.
"""
super(UnionClusterResolver, self).__init__()
self._rpc_layer = kwargs.pop('rpc_layer', None)
self._task_type = kwargs.pop('task_type', None)
self._task_index = kwargs.pop('task_index', None)
if kwargs:
raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs))
if not args:
raise ValueError('At least one ClusterResolver is required.')
for cluster_resolver in args:
if not isinstance(cluster_resolver, ClusterResolver):
raise TypeError('All arguments must be a sub-class of '
'`ClusterResolver.`')
self._cluster_resolvers = args
def cluster_spec(self):
"""Returns a union of all the ClusterSpecs from the ClusterResolvers.
Returns:
A ClusterSpec containing host information merged from all the underlying
ClusterResolvers.
Raises:
KeyError: If there are conflicting keys detected when merging two or
more dictionaries, this exception is raised.
Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the
same job name, we will merge the list/dict of workers.
If *all* underlying ClusterSpecs expose the set of workers as lists, we will
concatenate the lists of workers, starting with the list of workers from
the first ClusterResolver passed into the constructor.
If *any* of the ClusterSpecs expose the set of workers as a dict, we will
treat all the sets of workers as dicts (even if they are returned as lists)
and will only merge them into a dict if there are no conflicting keys. If
there is a conflicting key, we will raise a `KeyError`.
"""
merged_cluster = {}
# We figure out whether it is all lists for a particular job, or whether
# there are dicts inside.
for cluster_resolver in self._cluster_resolvers:
cluster_spec = cluster_resolver.cluster_spec()
cluster_dict = cluster_spec.as_dict()
for job_name, tasks in cluster_dict.items():
if job_name in merged_cluster:
# If we see a dict, then we write a dict out regardless.
if isinstance(tasks, dict):
merged_cluster[job_name] = {}
else:
# We take whichever type is present.
if isinstance(tasks, list):
merged_cluster[job_name] = []
else:
merged_cluster[job_name] = {}
# We then do the merge as appropriate in merged_cluster[job].
for cluster_resolver in self._cluster_resolvers:
cluster_spec = cluster_resolver.cluster_spec()
cluster_dict = cluster_spec.as_dict()
for job_name, tasks in cluster_dict.items():
if isinstance(merged_cluster[job_name], list):
# We all have lists, we can just concatenate and be done.
merged_cluster[job_name].extend(tasks)
else:
if isinstance(tasks, list):
# We convert to a dictionary if the type is a list.
task_dict = dict(zip(range(0, len(tasks)), tasks))
else:
# We can simply make a copy (for update) and be done.
task_dict = tasks.copy()
# We detect if there are duplicates, and raise an error if so.
task_keys = set(task_dict)
merged_keys = set(merged_cluster[job_name].keys())
intersected_keys = task_keys.intersection(merged_keys)
if intersected_keys:
raise KeyError('Duplicate keys detected when merging two '
'ClusterSpecs: %s' % repr(intersected_keys))
# We do the merge after all the processing.
merged_cluster[job_name].update(task_dict)
return ClusterSpec(merged_cluster)
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Returns the master address to use when creating a session.
This usually returns the master from the first ClusterResolver passed in,
but you can override this by specifying the task_type and task_index.
Args:
task_type: (Optional) The type of the TensorFlow task of the master.
task_index: (Optional) The index of the TensorFlow task of the master.
rpc_layer: (Optional) The RPC protocol for the given cluster.
Returns:
The name or URL of the session master.
"""
if task_type is not None and task_index is not None:
master = self.cluster_spec().task_address(task_type, task_index)
return format_master_url(master, rpc_layer or self._rpc_layer)
return self._cluster_resolvers[0].master(rpc_layer=rpc_layer)
@property
def task_type(self):
return self._task_type or self._cluster_resolvers[0].task_type
@property
def task_index(self):
return self._task_index or self._cluster_resolvers[0].task_index
@task_type.setter
def task_type(self, task_type):
self._task_type = task_type
@task_index.setter
def task_index(self, task_index):
self._task_index = task_index
@property
def environment(self):
return self._cluster_resolvers[0].environment
def num_accelerators_per_worker(self, session_config=None):
return self._cluster_resolvers[0].num_accelerators_per_worker(
session_config)
@property
def rpc_layer(self):
return self._rpc_layer or self._cluster_resolvers[0].rpc_layer
@rpc_layer.setter
def rpc_layer(self, rpc_layer):
self._rpc_layer = rpc_layer
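The list-concatenation behavior described in cluster_spec() can be seen with two disjoint resolvers (editor's sketch; hostnames are illustrative):

ps = SimpleClusterResolver(ClusterSpec({'ps': ['ps0:2222']}))
workers = SimpleClusterResolver(
    ClusterSpec({'worker': ['worker0:2222', 'worker1:2222']}))
union = UnionClusterResolver(ps, workers)
print(union.cluster_spec().as_dict())
# -> {'ps': ['ps0:2222'], 'worker': ['worker0:2222', 'worker1:2222']}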


@@ -1,4 +1,4 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,197 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for GCE Instance Groups."""
"""Stub file for GceClusterResolver to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# This file (and all files in this directory in general) is a backwards
# compatibility shim that exists to re-export ClusterResolvers such that
# existing OSS code will not be broken.
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
# pylint: disable=unused-import
from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver
# pylint: enable=unused-import
_GOOGLE_API_CLIENT_INSTALLED = True
try:
from googleapiclient import discovery # pylint: disable=g-import-not-at-top
from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top
except ImportError:
_GOOGLE_API_CLIENT_INSTALLED = False
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'GceClusterResolver',
]
def _format_master_url(master, rpc_layer=None):
return '%s://%s' % (rpc_layer, master) if rpc_layer else master
class GceClusterResolver(ClusterResolver):
"""Cluster Resolver for Google Compute Engine.
This is an implementation of cluster resolvers for the Google Compute Engine
instance group platform. By specifying a project, zone, and instance group,
this will retrieve the IP addresses of all the instances within the instance
group and return a Cluster Resolver object suitable for use with distributed
TensorFlow.
"""
def __init__(self,
project,
zone,
instance_group,
port,
task_type='worker',
task_index=0,
rpc_layer='grpc',
num_accelerators_per_worker=0,
credentials='default',
service=None):
"""Creates a new GceClusterResolver object.
This takes in a few parameters and creates a GceClusterResolver object. It
will then use these parameters to query the GCE API for the IP addresses of
each instance in the instance group.
Args:
project: Name of the GCE project.
zone: Zone of the GCE instance group.
instance_group: Name of the GCE instance group.
port: Port of the listening TensorFlow server (default: 8470)
task_type: Name of the TensorFlow job this GCE instance group of VM
instances belong to.
task_index: The task index for this particular VM, within the GCE
instance group. In particular, every single instance should be assigned
a unique ordinal index within an instance group manually so that they
can be distinguished from each other.
rpc_layer: The RPC layer TensorFlow should use to communicate across
instances.
num_accelerators_per_worker: Number of accelerators (GPUs) present per
instance.
credentials: GCE Credentials. If nothing is specified, this defaults to
GoogleCredentials.get_application_default().
service: The GCE API object returned by the googleapiclient.discovery
function. (Default: discovery.build('compute', 'v1')). If you specify a
custom service object, then the credentials parameter will be ignored.
Raises:
ImportError: If the googleapiclient is not installed.
"""
self._project = project
self._zone = zone
self._instance_group = instance_group
self._task_type = task_type
self._task_index = task_index
self._rpc_layer = rpc_layer
self._port = port
self._credentials = credentials
self._num_accelerators_per_worker = num_accelerators_per_worker
if credentials == 'default':
if _GOOGLE_API_CLIENT_INSTALLED:
self._credentials = GoogleCredentials.get_application_default()
if service is None:
if not _GOOGLE_API_CLIENT_INSTALLED:
raise ImportError('googleapiclient must be installed before using the '
'GCE cluster resolver')
self._service = discovery.build(
'compute', 'v1',
credentials=self._credentials)
else:
self._service = service
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest instance group info.
This returns a ClusterSpec object for use based on information from the
specified instance group. We will retrieve the information from the GCE APIs
every time this method is called.
Returns:
A ClusterSpec containing host information retrieved from GCE.
"""
request_body = {'instanceState': 'RUNNING'}
request = self._service.instanceGroups().listInstances(
project=self._project,
zone=self._zone,
instanceGroups=self._instance_group,
body=request_body,
orderBy='name')
worker_list = []
while request is not None:
response = request.execute()
items = response['items']
for instance in items:
instance_name = instance['instance'].split('/')[-1]
instance_request = self._service.instances().get(
project=self._project,
zone=self._zone,
instance=instance_name)
if instance_request is not None:
instance_details = instance_request.execute()
ip_address = instance_details['networkInterfaces'][0]['networkIP']
instance_url = '%s:%s' % (ip_address, self._port)
worker_list.append(instance_url)
request = self._service.instanceGroups().listInstances_next(
previous_request=request,
previous_response=response)
worker_list.sort()
return ClusterSpec({self._task_type: worker_list})
def master(self, task_type=None, task_index=None, rpc_layer=None):
task_type = task_type if task_type is not None else self._task_type
task_index = task_index if task_index is not None else self._task_index
if task_type is not None and task_index is not None:
master = self.cluster_spec().task_address(task_type, task_index)
if rpc_layer or self._rpc_layer:
return '%s://%s' % (rpc_layer or self._rpc_layer, master)
else:
return master
return ''
@property
def task_type(self):
return self._task_type
@property
def task_index(self):
return self._task_index
@task_type.setter
def task_type(self, task_type):
raise RuntimeError(
'You cannot reset the task_type of the GceClusterResolver after it has '
'been created.')
@task_index.setter
def task_index(self, task_index):
self._task_index = task_index
@property
def environment(self):
"""Returns the current environment which TensorFlow is running in.
For users in the GCE environment, the environment property is always an
empty string, and Google users will not use this ClusterResolver for running
on internal systems.
"""
return ''
@property
def rpc_layer(self):
return self._rpc_layer
@rpc_layer.setter
def rpc_layer(self, rpc_layer):
self._rpc_layer = rpc_layer
def num_accelerators_per_worker(self, session_config=None):
del session_config # Unused, since this is set manually in __init__.
return self._num_accelerators_per_worker
remove_undocumented(__name__, _allowed_symbols)
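How GceClusterResolver consumes the Compute Engine API can be exercised without real credentials by passing a stubbed service object, which the constructor accepts. This is an editor's sketch using unittest.mock; project, zone, instance group, and IP values are made up:

from unittest import mock

fake_service = mock.MagicMock()
list_req = fake_service.instanceGroups.return_value.listInstances.return_value
list_req.execute.return_value = {
    'items': [{'instance': 'projects/p/zones/z/instances/tf-worker-0'}]}
fake_service.instanceGroups.return_value.listInstances_next.return_value = None
get_req = fake_service.instances.return_value.get.return_value
get_req.execute.return_value = {
    'networkInterfaces': [{'networkIP': '10.1.2.3'}]}

resolver = GceClusterResolver(project='my-project', zone='us-central1-a',
                              instance_group='tf-workers', port=8470,
                              credentials=None, service=fake_service)
print(resolver.cluster_spec().as_dict())  # -> {'worker': ['10.1.2.3:8470']}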


@@ -12,162 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Kubernetes."""
"""Stub file for KubernetesClusterResolver for backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import format_master_url
from tensorflow.python.client import device_lib
from tensorflow.python.training import server_lib
# This file (and all files in this directory in general) is a backwards
# compatibility shim that exists to re-export ClusterResolvers such that
# existing OSS code will not be broken.
_KUBERNETES_API_CLIENT_INSTALLED = True
try:
from kubernetes import client as k8sclient # pylint: disable=g-import-not-at-top
from kubernetes import config as k8sconfig # pylint: disable=g-import-not-at-top
except ImportError:
_KUBERNETES_API_CLIENT_INSTALLED = False
# pylint: disable=unused-import
from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
class KubernetesClusterResolver(ClusterResolver):
"""Cluster Resolver for Kubernetes.
_allowed_symbols = [
'KubernetesClusterResolver',
]
This is an implementation of cluster resolvers for Kubernetes. When given the
Kubernetes namespace and label selector for pods, we will retrieve the
pod IP addresses of all running pods matching the selector, and return a
ClusterSpec based on that information.
"""
remove_undocumented(__name__, _allowed_symbols)
def __init__(self,
job_to_label_mapping=None,
tf_server_port=8470,
rpc_layer='grpc',
override_client=None):
"""Initializes a new KubernetesClusterResolver.
This initializes a new Kubernetes Cluster Resolver. The Cluster Resolver
will attempt to talk to the Kubernetes master to retrieve all the instances
of pods matching a label selector.
Args:
job_to_label_mapping: A mapping of TensorFlow jobs to label selectors.
This allows users to specify many TensorFlow jobs in one Cluster
Resolver, and each job can have pods belonging to different label
selectors. For example, a sample mapping might be
```
{'worker': ['job-name=worker-cluster-a', 'job-name=worker-cluster-b'],
'ps': ['job-name=ps-1', 'job-name=ps-2']}
```
tf_server_port: The port the TensorFlow server is listening on.
rpc_layer: (Optional) The RPC layer TensorFlow should use to communicate
between tasks in Kubernetes. Defaults to 'grpc'.
override_client: The Kubernetes client (usually automatically retrieved
using `from kubernetes import client as k8sclient`). If you pass this
in, you are responsible for setting Kubernetes credentials manually.
Raises:
ImportError: If the Kubernetes Python client is not installed and no
`override_client` is passed in.
RuntimeError: If autoresolve_task is not a boolean or a callable.
"""
if _KUBERNETES_API_CLIENT_INSTALLED:
k8sconfig.load_kube_config()
if not job_to_label_mapping:
job_to_label_mapping = {'worker': ['job-name=tensorflow']}
if not override_client and not _KUBERNETES_API_CLIENT_INSTALLED:
raise ImportError('The Kubernetes Python client must be installed before'
'using the Kubernetes Cluster Resolver. To install the'
'Kubernetes Python client, run `pip install '
'kubernetes` on your command line.')
self._job_to_label_mapping = job_to_label_mapping
self._tf_server_port = tf_server_port
self._override_client = override_client
self.task_type = None
self.task_index = None
self.rpc_layer = rpc_layer
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Returns the master address to use when creating a session.
You must have set the task_type and task_index object properties before
calling this function, or pass in the `task_type` and `task_index`
parameters when using this function. If you do both, the function parameters
will override the object properties.
Args:
task_type: (Optional) The type of the TensorFlow task of the master.
task_index: (Optional) The index of the TensorFlow task of the master.
rpc_layer: (Optional) The RPC protocol for the given cluster.
Returns:
The name or URL of the session master.
"""
if task_type is not None and task_index is not None:
return format_master_url(
self.cluster_spec().task_address(task_type, task_index),
rpc_layer or self.rpc_layer)
if self.task_type is not None and self.task_index is not None:
return format_master_url(
self.cluster_spec().task_address(self.task_type, self.task_index),
rpc_layer or self.rpc_layer)
return ''
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest info from Kubernetes.
We retrieve the information from the Kubernetes master every time this
method is called.
Returns:
A ClusterSpec containing host information returned from Kubernetes.
Raises:
RuntimeError: If any of the pods returned by the master is not in the
`Running` phase.
"""
if not self._override_client:
k8sconfig.load_kube_config()
client = self._override_client or k8sclient.CoreV1Api()
cluster_map = {}
for tf_job in self._job_to_label_mapping:
all_pods = []
for selector in self._job_to_label_mapping[tf_job]:
ret = client.list_pod_for_all_namespaces(label_selector=selector)
selected_pods = []
# Sort the list by the name to make sure it doesn't change call to call.
for pod in sorted(ret.items, key=lambda x: x.metadata.name):
if pod.status.phase == 'Running':
selected_pods.append(
'%s:%s' % (pod.status.host_ip, self._tf_server_port))
else:
raise RuntimeError('Pod "%s" is not running; phase: "%s"' %
(pod.metadata.name, pod.status.phase))
all_pods.extend(selected_pods)
cluster_map[tf_job] = all_pods
return server_lib.ClusterSpec(cluster_map)
@property
def environment(self):
"""Returns the current environment which TensorFlow is running in.
For users in the Cloud environment, the environment property is always an
empty string, and Google users will not use this ClusterResolver for running
on internal systems.
"""
return ''
def num_accelerators_per_worker(self, session_config=None):
local_devices = device_lib.list_local_devices(session_config)
return len([d for d in local_devices if d.device_type == 'GPU'])
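A similar mock-based sketch for KubernetesClusterResolver, using the override_client hook (editor's illustration; pod names and IPs are made up, and it assumes either that the kubernetes package is absent or that a valid kubeconfig exists, since __init__ calls load_kube_config() whenever the package is installed):

from unittest import mock

pod = mock.Mock()
pod.metadata.name = 'tf-worker-0'
pod.status.phase = 'Running'
pod.status.host_ip = '10.1.2.3'

fake_client = mock.Mock()
fake_client.list_pod_for_all_namespaces.return_value.items = [pod]

resolver = KubernetesClusterResolver(
    job_to_label_mapping={'worker': ['job-name=tensorflow']},
    tf_server_port=8470,
    override_client=fake_client)
print(resolver.cluster_spec().as_dict())  # -> {'worker': ['10.1.2.3:8470']}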


@@ -12,215 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Slurm workload manager."""
"""Stub file for SlurmClusterResolver to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import subprocess
# This file (and all files in this directory in general) is a backwards
# compatibility shim that exists to re-export ClusterResolvers such that
# existing OSS code will not be broken.
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
# pylint: disable=unused-import
from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
class SlurmClusterResolver(ClusterResolver):
"""Cluster Resolver for system with Slurm workload manager.
_allowed_symbols = [
'SlurmClusterResolver',
]
This is an implementation of cluster resolvers for Slurm clusters. This allows
the specification of jobs and task counts, number of tasks per node, number of
GPUs on each node and number of GPUs for each task. It retrieves system
attributes by Slurm environment variables, resolves allocated computing node
names, constructs a cluster and returns a Cluster Resolver object which can be
used for distributed TensorFlow.
"""
def _resolve_hostnames(self):
"""Resolve host names of nodes allocated in current jobs.
Returns:
A list of node names as strings.
"""
hostlist = (subprocess.check_output(['scontrol', 'show', 'hostname']).
decode('utf-8').strip().split('\n'))
return hostlist
def __init__(self,
jobs,
port_base=8888,
gpus_per_node=1,
gpus_per_task=1,
tasks_per_node=None,
auto_set_gpu=True,
rpc_layer='grpc'):
"""Creates a new SlurmClusterResolver object.
This takes in parameters and creates a SlurmClusterResolver object. It uses
those parameters to check on which nodes processes will reside and resolves
their hostnames. With the number of GPUs on each node and number of GPUs
for each task it offsets the port number for each process and allocates
GPUs to tasks by setting environment variables. The resolver currently
supports homogeneous tasks and default Slurm process allocation.
Args:
jobs: Dictionary with job names as key and number of tasks in the job as
value
port_base: The first port number to start with for processes on a node.
gpus_per_node: Number of GPUs available on each node.
gpus_per_task: Number of GPUs to be used for each task.
tasks_per_node: Number of tasks to run on each node, if not set defaults
to Slurm's output environment variable SLURM_NTASKS_PER_NODE.
auto_set_gpu: Set the visible CUDA devices automatically while resolving
the cluster by setting CUDA_VISIBLE_DEVICES environment variable.
Defaults to True.
rpc_layer: (Optional) The protocol TensorFlow uses to communicate between
nodes. Defaults to 'grpc'.
Returns:
A ClusterResolver object which can be used with distributed TensorFlow.
Raises:
RuntimeError: If requested more GPUs per node than available or requested
more tasks than assigned tasks.
"""
# check if launched by mpirun
if 'OMPI_COMM_WORLD_RANK' in os.environ:
self._rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
num_tasks = int(os.environ['OMPI_COMM_WORLD_SIZE'])
else:
self._rank = int(os.environ['SLURM_PROCID'])
num_tasks = int(os.environ['SLURM_NTASKS'])
self._jobs = collections.OrderedDict(sorted(jobs.items()))
self._port_base = port_base
# user specification overrides SLURM specification
if tasks_per_node is not None:
self._tasks_per_node = tasks_per_node
elif tasks_per_node is None and 'SLURM_NTASKS_PER_NODE' in os.environ:
self._tasks_per_node = int(os.environ['SLURM_NTASKS_PER_NODE'])
else:
raise RuntimeError('Neither `tasks_per_node` or '
'SLURM_NTASKS_PER_NODE is set.')
self._gpus_per_node = gpus_per_node
self._gpus_per_task = gpus_per_task
self._auto_set_gpu = auto_set_gpu
self.task_type = None
self.task_index = None
self.rpc_layer = rpc_layer
self._gpu_allocation = []
self._cluster_allocation = {}
if self._tasks_per_node * self._gpus_per_task > self._gpus_per_node:
raise RuntimeError('Requested more GPUs per node then available.')
if sum(self._jobs.values()) != num_tasks:
raise RuntimeError('Requested more tasks then assigned tasks.')
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest instance group info.
This returns a ClusterSpec object for use based on information from the
specified initialization parameters and Slurm environment variables. The
cluster specification is resolved each time this function is called. The
resolver extracts hostnames of nodes via scontrol and packs tasks in that
order until a node has a number of tasks that is equal to the specification.
GPUs on nodes are allocated to tasks by specification through setting
CUDA_VISIBLE_DEVICES environment variable.
Returns:
A ClusterSpec containing host information retrieved from Slurm's
environment variables.
"""
hostlist = self._resolve_hostnames()
task_list = []
self._gpu_allocation = []
self._cluster_allocation = {}
for host in hostlist:
for port_offset, gpu_offset in zip(
range(self._tasks_per_node),
range(0, self._gpus_per_node, self._gpus_per_task)):
host_addr = '%s:%d' % (host, self._port_base + port_offset)
task_list.append(host_addr)
gpu_id_list = []
for gpu_id in range(gpu_offset, gpu_offset + self._gpus_per_task):
gpu_id_list.append(str(gpu_id))
self._gpu_allocation.append(','.join(gpu_id_list))
cluster_rank_offset_start = 0
cluster_rank_offset_end = 0
for task_type, num_tasks in self._jobs.items():
cluster_rank_offset_end = cluster_rank_offset_start + num_tasks
self._cluster_allocation[task_type] = (
task_list[cluster_rank_offset_start:cluster_rank_offset_end])
if cluster_rank_offset_start <= self._rank < cluster_rank_offset_end:
self.task_type = task_type
self.task_index = self._rank - cluster_rank_offset_start
cluster_rank_offset_start = cluster_rank_offset_end
if self._auto_set_gpu is True:
os.environ['CUDA_VISIBLE_DEVICES'] = self._gpu_allocation[self._rank]
return ClusterSpec(self._cluster_allocation)
def get_task_info(self):
"""Returns job name and task_index for the process which calls this.
This returns the job name and task index for the process which calls this
function according to its rank and cluster specification. The job name and
task index are set after a cluster is constructed by cluster_spec; otherwise
they default to None.
Returns:
A string specifying the job name the process belongs to and an integer
specifying the task index the process belongs to in that job.
"""
return self.task_type, self.task_index
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Returns the master string for connecting to a TensorFlow master.
Args:
task_type: (Optional) Overrides the default auto-selected task type.
task_index: (Optional) Overrides the default auto-selected task index.
rpc_layer: (Optional) Overrides the default RPC protocol TensorFlow uses
to communicate across nodes.
Returns:
A connection string for connecting to a TensorFlow master.
"""
task_type = task_type if task_type is not None else self.task_type
task_index = task_index if task_index is not None else self.task_index
rpc_layer = rpc_layer or self.rpc_layer
master = self.cluster_spec().task_address(task_type, task_index)
return '%s://%s' % (rpc_layer, master) if rpc_layer else master
@property
def environment(self):
"""Returns the current environment which TensorFlow is running in.
For users in the Slurm environment, the environment property is always an
empty string, and Google users will not use this ClusterResolver for running
on internal systems.
"""
return ''
def num_accelerators_per_worker(self, session_config=None):
del session_config # Unused, since this is set in __init__ manually.
return self._gpus_per_node
remove_undocumented(__name__, _allowed_symbols)
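The Slurm resolution flow can be simulated without a real Slurm allocation by faking the relevant environment variables and bypassing scontrol (editor's sketch; node names are illustrative, and it assumes no OMPI_COMM_WORLD_RANK variable is set, since that path takes precedence):

import os

os.environ['SLURM_PROCID'] = '0'
os.environ['SLURM_NTASKS'] = '4'

resolver = SlurmClusterResolver(jobs={'ps': 1, 'worker': 3},
                                port_base=8888,
                                gpus_per_node=2,
                                gpus_per_task=1,
                                tasks_per_node=2,
                                auto_set_gpu=False)
# Bypass the scontrol call for this illustration.
resolver._resolve_hostnames = lambda: ['node0', 'node1']
print(resolver.cluster_spec().as_dict())
# -> {'ps': ['node0:8888'],
#     'worker': ['node0:8889', 'node1:8888', 'node1:8889']}
print(resolver.get_task_info())  # -> ('ps', 0) for SLURM_PROCID=0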


@@ -12,160 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables."""
"""Stub file for TFConfigClusterResolver to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
# This file (and all files in this directory in general) is a backwards
# compatibility shim that exists to re-export ClusterResolvers such that
# existing OSS code will not be broken.
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
# pylint: disable=unused-import
from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver
# pylint: enable=unused-import
_TF_CONFIG_ENV = 'TF_CONFIG'
_SESSION_MASTER_KEY = 'session_master'
_RPC_LAYER_KEY = 'rpc_layer'
_TASK_KEY = 'task'
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'TFConfigClusterResolver',
]
def format_master_url(master, rpc_layer=None):
if rpc_layer:
return '%s://%s' % (rpc_layer, master)
else:
return master
remove_undocumented(__name__, _allowed_symbols)
def _load_tf_config():
return json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
def _get_value_in_tfconfig(key, default=None):
tf_config = _load_tf_config()
return tf_config[key] if key in tf_config else default
class TFConfigClusterResolver(ClusterResolver):
"""Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar."""
def __init__(self,
task_type=None,
task_index=None,
rpc_layer=None,
environment=None,
num_accelerators_per_worker=0):
"""Creates a new TFConfigClusterResolver.
Args:
task_type: (String, optional) Overrides the task type specified in the
TF_CONFIG environment variable.
task_index: (Integer, optional) Overrides the task index specified in the
TF_CONFIG environment variable.
rpc_layer: (String, optional) Overrides the rpc layer TensorFlow uses.
environment: (String, optional) Overrides the environment TensorFlow
operates in.
num_accelerators_per_worker: (Integer, optional) Specifies the number of
accelerators (e.g. GPUs, TPUs, others) that each node has.
"""
self._task_type = task_type
self._task_index = task_index
self._rpc_layer = rpc_layer
self._environment = environment
self._num_accelerators_per_worker = num_accelerators_per_worker
@property
def task_type(self):
if self._task_type is None:
task_info = _get_value_in_tfconfig(_TASK_KEY, {})
return task_info['type'] if 'type' in task_info else None
else:
return self._task_type
@property
def task_index(self):
if self._task_index is None:
task_info = _get_value_in_tfconfig(_TASK_KEY, {})
return task_info['index'] if 'index' in task_info else None
else:
return self._task_index
@task_type.setter
def task_type(self, task_type):
self._task_type = task_type
@task_index.setter
def task_index(self, task_index):
self._task_index = task_index
@property
def environment(self):
return self._environment
@property
def rpc_layer(self):
if self._rpc_layer is None:
return _get_value_in_tfconfig(_RPC_LAYER_KEY)
else:
return self._rpc_layer
@rpc_layer.setter
def rpc_layer(self, rpc_layer):
self._rpc_layer = rpc_layer
def num_accelerators_per_worker(self, session_config=None):
# TODO(frankchn): Connect to server (w/ session_config) in the future.
del session_config # Unused, we do not connect to another server here.
return self._num_accelerators_per_worker
def cluster_spec(self):
"""Returns a ClusterSpec based on the TF_CONFIG environment variable.
Returns:
A ClusterSpec with information from the TF_CONFIG environment variable.
"""
tf_config = _load_tf_config()
if 'cluster' not in tf_config:
return ClusterSpec({})
return ClusterSpec(tf_config['cluster'])
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Returns the master address to use when creating a TensorFlow session.
Args:
task_type: (String, optional) Overrides and sets the task_type of the
master.
task_index: (Integer, optional) Overrides and sets the task id of the
master.
rpc_layer: (String, optional) Overrides and sets the protocol over which
TensorFlow nodes communicate with each other.
Returns:
The address of the master.
Raises:
RuntimeError: If the task_type or task_id is not specified and the
`TF_CONFIG` environment variable does not contain a task section.
"""
# If `session_master` is set, just use that.
session_master = _get_value_in_tfconfig(_SESSION_MASTER_KEY)
if session_master is not None:
return session_master
# Return an empty string if we are the only job in the ClusterSpec.
cluster_spec = self.cluster_spec()
if (not cluster_spec.jobs or
(len(cluster_spec.jobs) == 1 and
len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)):
return ''
# We try to auto-detect the task type and id, but use the user-supplied one
# where available
task_type = task_type if task_type is not None else self.task_type
task_index = task_index if task_index is not None else self.task_index
return format_master_url(cluster_spec.task_address(task_type, task_index),
self.rpc_layer)
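End to end, the TF_CONFIG flow looks like this (editor's sketch with an illustrative cluster; the resolver reads everything from the environment variable):

import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'ps': ['ps0:2222'],
                'worker': ['worker0:2222', 'worker1:2222']},
    'task': {'type': 'worker', 'index': 1},
    'rpc_layer': 'grpc',
})
resolver = TFConfigClusterResolver()
print(resolver.task_type, resolver.task_index)  # -> worker 1
print(resolver.cluster_spec().as_dict())
print(resolver.master())                        # -> grpc://worker1:2222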


@@ -1,4 +1,4 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,412 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Cloud TPUs."""
"""Stub file for TPUClusterResolver to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# This file (and all files in this directory in general) is a backwards
# compatibility shim that exists to re-export ClusterResolvers such that
# existing OSS code will not be broken.
from six.moves.urllib.request import Request
from six.moves.urllib.request import urlopen
# pylint: disable=unused-import
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
# pylint: enable=unused-import
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import format_master_url
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util.all_util import remove_undocumented
_GOOGLE_API_CLIENT_INSTALLED = True
try:
from googleapiclient import discovery # pylint: disable=g-import-not-at-top
from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top
except ImportError:
_GOOGLE_API_CLIENT_INSTALLED = False
_allowed_symbols = [
'TPUClusterResolver',
]
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
class TPUClusterResolver(ClusterResolver):
"""Cluster Resolver for Google Cloud TPUs.
This is an implementation of cluster resolvers for the Google Cloud TPU
service. As Cloud TPUs are in alpha, you will need to specify an API definition
file for this to consume, in addition to a list of Cloud TPUs in your Google
Cloud Platform project.
"""
def _tpuService(self):
"""Creates a new Cloud TPU API object.
This works around an issue where the underlying HTTP connection sometimes
times out when the script has been running for too long. Other methods in
this object call this method to get a new API object whenever they need
to communicate with the Cloud API.
Returns:
A Google Cloud TPU API object.
"""
if self._service:
return self._service
credentials = self._credentials
if credentials is None or credentials == 'default':
credentials = GoogleCredentials.get_application_default()
if self._discovery_url:
return discovery.build(
'tpu', 'v1alpha1',
credentials=credentials,
discoveryServiceUrl=self._discovery_url)
else:
return discovery.build(
'tpu', 'v1alpha1',
credentials=credentials)
def _requestComputeMetadata(self, path):
req = Request('http://metadata/computeMetadata/v1/%s' % path,
headers={'Metadata-Flavor': 'Google'})
resp = urlopen(req)
return compat.as_bytes(resp.read())
def _shouldResolve(self):
if isinstance(self._should_resolve_override, bool):
return self._should_resolve_override
if (self._tpu == compat.as_bytes('') or
self._tpu == compat.as_bytes('local') or
self._tpu.startswith(compat.as_bytes('/bns')) or
self._tpu.startswith(compat.as_bytes('localhost:')) or
self._tpu.startswith(compat.as_bytes('grpc://'))):
return False
return True
@staticmethod
def _inGke():
"""When running in GKE, the environment variable will be set."""
return _GKE_ENV_VARIABLE in os.environ
@staticmethod
def _gkeEndpoints():
return os.environ[_GKE_ENV_VARIABLE]
@staticmethod
def _envVarFallback():
if _DEFAULT_ENV_VARIABLE in os.environ:
return os.environ[_DEFAULT_ENV_VARIABLE]
return None
@staticmethod
def _environmentDiscoveryUrl():
return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
def __init__(self,
tpu=None,
zone=None,
project=None,
job_name='worker',
coordinator_name=None,
coordinator_address=None,
credentials='default',
service=None,
discovery_url=None):
"""Creates a new TPUClusterResolver object.
The ClusterResolver will then use the parameters to query the Cloud TPU APIs
for the IP addresses and ports of each Cloud TPU listed.
Args:
tpu: Either a string, or a list of strings corresponding to the TPUs to
use. If the single string is the empty string, the string 'local', or a
string that begins with 'grpc://' or '/bns', then it is assumed to not
correspond with a Cloud TPU and will instead be passed as the session
master and no ClusterSpec propagation will be done.
zone: Zone where the TPUs are located. If omitted or empty, we will assume
that the zone of the TPU is the same as the zone of the GCE VM, which we
will try to discover from the GCE metadata service.
project: Name of the GCP project containing Cloud TPUs. If omitted or
empty, we will try to discover the project name of the GCE VM from the
GCE metadata service.
job_name: Name of the TensorFlow job the TPUs belong to.
coordinator_name: The name to use for the coordinator. Set to None if the
coordinator should not be included in the computed ClusterSpec.
coordinator_address: The address of the coordinator (typically an ip:port
pair). If set to None, a TF server will be started. If coordinator_name
is None, a TF server will not be started even if coordinator_address is
None.
credentials: GCE Credentials. If None, then we use default credentials
from the oauth2client
service: The GCE API object returned by the googleapiclient.discovery
function. If you specify a custom service object, then the credentials
parameter will be ignored.
discovery_url: A URL template that points to the location of
the discovery service. It should have two parameters {api} and
{apiVersion} that when filled in produce an absolute URL to the
discovery document for that service. The environment variable
'TPU_API_DISCOVERY_URL' will override this.
Raises:
ImportError: If the googleapiclient is not installed.
ValueError: If no TPUs are specified.
"""
if isinstance(tpu, list):
if not tpu:
raise ValueError('At least one TPU must be specified.')
if len(tpu) != 1:
raise NotImplementedError(
'Using multiple TPUs in a single session is not yet implemented')
tpu = tpu[0]
in_gke = self._inGke()
# When using GKE with Cloud TPUs, the env variable will be set.
if tpu is None:
if in_gke:
tpu = self._gkeEndpoints()
else:
tpu = self._envVarFallback()
if tpu is None:
raise ValueError('Please provide a TPU Name to connect to.')
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
# By default the task_type is 'worker' and the task_index is 0 (which is the
# first worker in the task).
self.task_type = job_name
self.task_index = 0
if tpu.startswith('grpc://'):
# Cloud environment, where we are using GRPC to communicate to TPUs.
self._environment = ''
elif tpu == 'local' or not tpu:
# Google environment, where the TPU is attached to the host.
self._environment = 'google'
elif tpu.startswith('/bns'):
# Google environment, where we reach the TPU through BNS.
self._environment = 'google'
# If TPU is in the Google environment or exists locally, we don't use any
# RPC layer.
if tpu.startswith('/bns') or tpu == 'local' or not tpu:
self.rpc_layer = None
else:
self.rpc_layer = 'grpc'
# Setting this overrides the return value of self._shouldResolve()
self._should_resolve_override = None
# We strip out the protocol if it is included, and override the
# shouldResolve function to never resolve. We are adding the protocol back
# in later in self.master().
if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'):
tpu = tpu[len(self.rpc_layer + '://'):]
self._tpu = tpu
self._should_resolve_override = False
# Whether we should actually attempt to contact Cloud APIs
should_resolve = self._shouldResolve()
# We raise an error if we need to contact the Cloud APIs but the standard
# client library is not installed and no custom service object was passed in.
self._service = service
if (self._service is None and should_resolve and
not _GOOGLE_API_CLIENT_INSTALLED):
raise ImportError('googleapiclient and oauth2client must be installed '
'before using the TPU cluster resolver. Execute: '
'`pip install --upgrade google-api-python-client` '
'and `pip install --upgrade oauth2client` to '
'install with pip.')
# We save user-passed credentials, unless the user didn't pass in anything.
self._credentials = credentials
if (credentials == 'default' and should_resolve and
_GOOGLE_API_CLIENT_INSTALLED):
self._credentials = None
# Automatically detect project and zone if unspecified.
if not project and should_resolve:
project = compat.as_str(
self._requestComputeMetadata('project/project-id'))
if not zone and should_resolve:
zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
zone = zone_path.split('/')[-1]
self._project = project
self._zone = zone
self._discovery_url = self._environmentDiscoveryUrl() or discovery_url
self._coordinator_name = coordinator_name
if (coordinator_name and not coordinator_address and
(should_resolve or in_gke)):
self._start_local_server()
else:
self._coordinator_address = coordinator_address
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Get the Master string to be used for the session.
In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of the
first instance in the ClusterSpec returned by the cluster_spec function.
If a non-TPU name is used when constructing a TPUClusterResolver, that will
be returned instead (e.g. if the tpu argument's value when constructing
this TPUClusterResolver was 'grpc://10.240.1.2:8470',
'grpc://10.240.1.2:8470' will be returned).
Args:
task_type: (Optional, string) The type of the TensorFlow task of the
master.
task_index: (Optional, integer) The index of the TensorFlow task of the
master.
rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
communicate with TPUs.
Returns:
string, the connection string to use when creating a session.
Raises:
ValueError: If none of the TPUs specified exists.
"""
if self._shouldResolve():
# We are going to communicate with the Cloud TPU APIs to get a Cluster.
cluster_spec = self.cluster_spec()
if task_type is not None and task_index is not None:
# task_type and task_index are from the function parameters
master = cluster_spec.task_address(task_type, task_index)
elif self.task_type is not None and self.task_index is not None:
# task_type and task_index are from the object
master = cluster_spec.task_address(self.task_type, self.task_index)
else:
# by default we take the first item in the cluster with the right name
job_tasks = cluster_spec.job_tasks(self.task_type)
if not job_tasks:
raise ValueError('No TPUs with the specified names exist.')
master = job_tasks[0]
else:
if isinstance(self._tpu, (bytes, bytearray)):
master = self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0]
else:
master = self._tpu.split(_ENDPOINTS_SEPARATOR)[0]
return format_master_url(master, rpc_layer or self.rpc_layer)
def get_master(self):
return self.master()
def get_job_name(self):
if self._shouldResolve():
return self.task_type
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
We retrieve the information from the Cloud TPU APIs every time this method is
called.
Returns:
A ClusterSpec containing host information returned from Cloud TPUs.
Raises:
RuntimeError: If the provided TPU is not healthy.
"""
############################################################################
# There are 5 potential cases this code must handle:
# 1. [Normal case.] We should resolve the TPU name to a set of tasks, and
# a. Create a ClusterSpec that includes the coordinator job
# b. Create a ClusterSpec without the coordinator job.
# 2. [GKE / No API Access.] We should not resolve the TPU name to a set of
# tasks and
# a. Create a ClusterSpec with the coordinator
# b. Create a ClusterSpec without the coordinator
# 3. [Other (legacy non-gRPC).] We should return None (no ClusterSpec).
############################################################################
if self._shouldResolve():
# Case 1.
full_name = 'projects/%s/locations/%s/nodes/%s' % (
self._project, self._zone, compat.as_text(self._tpu))
service = self._tpuService()
request = service.projects().locations().nodes().get(name=full_name)
response = request.execute()
if 'state' in response and response['state'] != 'READY':
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
(compat.as_text(self._tpu), response['state']))
if 'health' in response and response['health'] != 'HEALTHY':
raise RuntimeError('TPU "%s" is unhealthy: "%s"' %
(compat.as_text(self._tpu), response['health']))
if 'networkEndpoints' in response:
worker_list = [
'%s:%s' % (endpoint['ipAddress'], endpoint['port'])
for endpoint in response['networkEndpoints']
]
else:
# Fall back to the deprecated response format
instance_url = '%s:%s' % (response['ipAddress'], response['port'])
worker_list = [instance_url]
cluster_spec = {self.task_type: worker_list}
else:
if self.rpc_layer is None:
# Case 3.
return None
# Case 2.
tpus = []
for tpu in self._tpu.split(_ENDPOINTS_SEPARATOR):
# We are working around the fact that GKE environment variable that is
# supplied to us has the protocol string embedded in it, but we want
# to strip it out for the ClusterSpec.
if (self.rpc_layer is not None and
tpu.startswith(self.rpc_layer + '://')):
tpus.append(tpu[len(self.rpc_layer + '://'):])
else:
tpus.append(tpu)
cluster_spec = {self.task_type: tpus}
if self._coordinator_address:
# {1, 2}.a
cluster_spec[self._coordinator_name] = [self._coordinator_address]
return server_lib.ClusterSpec(cluster_spec)
def num_accelerators_per_worker(self, session_config=None):
"""Returns the number of TPU cores per worker.
This defaults to 8 for all current TPU configurations, and we do not need
to query any remote systems for this.
Args:
session_config: Unused. Not currently necessary to query anything as this
number is 8 for all TPU configurations.
"""
del session_config # Unused. Not necessary to query anything.
return 8
@property
def environment(self):
"""Returns the current environment which TensorFlow is running in."""
return self._environment
def _start_local_server(self):
address = self._requestComputeMetadata('instance/network-interfaces/0/ip')
self._server = server_lib.Server(
{
'local': ['0.0.0.0:0']
}, protocol='grpc', config=None, start=True)
# self._server.target is of the form: grpc://ipaddress:port
target = compat.as_bytes(self._server.target)
splits = target.split(compat.as_bytes(':'))
assert len(splits) == 3, self._server.target
assert splits[0] == compat.as_bytes('grpc'), self._server.target
self._coordinator_port = compat.as_text(splits[2])
self._coordinator_address = '%s:%s' % (
address, compat.as_text(self._coordinator_port))
def __deepcopy__(self, memo):
# TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
return self
remove_undocumented(__name__, _allowed_symbols)
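As a hedged illustration only (not part of this commit), a minimal sketch of how the resolver above is typically driven. The TPU name, zone, and project are assumptions, and googleapiclient/oauth2client must be installed for the Cloud API calls to succeed:

from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver

# Assumed: a Cloud TPU named 'demo-tpu' exists in this project and zone.
resolver = TPUClusterResolver(
    tpu='demo-tpu', zone='us-central1-b', project='my-gcp-project')
cluster_spec = resolver.cluster_spec()  # Queries the Cloud TPU API each call.
master = resolver.master()              # e.g. 'grpc://10.240.1.2:8470'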

View File

@ -216,7 +216,7 @@ py_library(
],
deps = [
":tpu_lib",
"//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/distribute",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
@ -264,7 +264,7 @@ py_library(
":tpu_py",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",
"//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/compiler:xla",
"//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
"//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py",

View File

@ -3534,6 +3534,19 @@ py_library(
],
)
# Dependency added and used by ClusterResolvers to avoid circular dependency between keras, distribute, and training.
py_library(
name = "training_server_lib",
srcs = ["training/server_lib.py"],
srcs_version = "PY2AND3",
deps = [
":framework",
":pywrap_tensorflow",
":util",
"//tensorflow/core:protos_all_py",
],
)
py_library(
name = "saveable_object",
srcs = ["training/saveable_object.py"],

View File

@ -124,6 +124,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/data",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/ops/losses",
"//tensorflow/tools/docs:doc_controls",
],

View File

@ -0,0 +1,180 @@
# Description: Operations defined for Cluster Resolvers
load("//tensorflow:tensorflow.bzl", "tf_py_test")
package(
default_visibility = [
"//tensorflow:__subpackages__",
],
)
licenses(["notice"]) # Apache 2.0
py_library(
name = "cluster_resolver_lib",
srcs = [
"__init__.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":base_cluster_resolver_py",
":gce_cluster_resolver_py",
":kubernetes_cluster_resolver_py",
":slurm_cluster_resolver_py",
":tfconfig_cluster_resolver_py",
":tpu_cluster_resolver_py",
"//tensorflow/python:util",
],
)
py_library(
name = "base_cluster_resolver_py",
srcs = ["cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:training_server_lib",
],
)
py_library(
name = "gce_cluster_resolver_py",
srcs = ["gce_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training_server_lib",
],
)
py_library(
name = "tfconfig_cluster_resolver_py",
srcs = ["tfconfig_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training_server_lib",
],
)
py_library(
name = "tpu_cluster_resolver_py",
srcs = ["tpu_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training_server_lib",
],
)
py_library(
name = "slurm_cluster_resolver_py",
srcs = ["slurm_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training_server_lib",
],
)
py_library(
name = "kubernetes_cluster_resolver_py",
srcs = ["kubernetes_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":base_cluster_resolver_py",
"//tensorflow/python:training_server_lib",
],
)
tf_py_test(
name = "base_cluster_resolver_py_test",
srcs = ["cluster_resolver_test.py"],
additional_deps = [
":base_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
main = "cluster_resolver_test.py",
)
tf_py_test(
name = "gce_cluster_resolver_py_test",
size = "small",
srcs = ["gce_cluster_resolver_test.py"],
additional_deps = [
":gce_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
main = "gce_cluster_resolver_test.py",
)
tf_py_test(
name = "tfconfig_cluster_resolver_py_test",
size = "small",
srcs = ["tfconfig_cluster_resolver_test.py"],
additional_deps = [
":tfconfig_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "tfconfig_cluster_resolver_test.py",
)
tf_py_test(
name = "tpu_cluster_resolver_py_test",
size = "small",
srcs = ["tpu_cluster_resolver_test.py"],
additional_deps = [
":tpu_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "tpu_cluster_resolver_test.py",
)
tf_py_test(
name = "slurm_cluster_resolver_py_test",
size = "small",
srcs = ["slurm_cluster_resolver_test.py"],
additional_deps = [
":slurm_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
main = "slurm_cluster_resolver_test.py",
tags = [],
)
tf_py_test(
name = "kubernetes_cluster_resolver_py_test",
size = "small",
srcs = ["kubernetes_cluster_resolver_test.py"],
additional_deps = [
":kubernetes_cluster_resolver_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
main = "kubernetes_cluster_resolver_test.py",
)

View File

@ -0,0 +1,57 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library Imports for Cluster Resolvers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.distribute.cluster_resolver import gce_cluster_resolver
from tensorflow.python.distribute.cluster_resolver import kubernetes_cluster_resolver
from tensorflow.python.distribute.cluster_resolver import slurm_cluster_resolver
from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver
from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver
from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver
from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'cluster_resolver',
'gce_cluster_resolver',
'kubernetes_cluster_resolver',
'slurm_cluster_resolver',
'tfconfig_cluster_resolver',
'tpu_cluster_resolver',
'ClusterResolver',
'SimpleClusterResolver',
'UnionClusterResolver',
'GceClusterResolver',
'KubernetesClusterResolver',
'TFConfigClusterResolver',
'TPUClusterResolver',
'SlurmClusterResolver',
]
remove_undocumented(__name__, _allowed_symbols)
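A minimal sketch (author's illustration, not part of this commit) of the import surface this module exposes; both the submodules and the concrete classes listed above are re-exported from the new package path:

# Both the submodule and the class it defines are importable here.
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
assert tpu_cluster_resolver.TPUClusterResolver is TPUClusterResolver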

View File

@ -0,0 +1,374 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.python.training.server_lib import ClusterSpec
def format_master_url(master, rpc_layer=None):
if rpc_layer:
return '%s://%s' % (rpc_layer, master)
else:
return master
@six.add_metaclass(abc.ABCMeta)
class ClusterResolver(object):
"""Abstract class for all implementations of ClusterResolvers.
This defines the skeleton for all implementations of ClusterResolvers.
ClusterResolvers are a way for TensorFlow to communicate with various cluster
management systems (e.g. GCE, AWS, etc...).
By letting TensorFlow communicate with these systems, we will be able to
automatically discover and resolve IP addresses for various TensorFlow
workers. This will eventually allow us to automatically recover from
underlying machine failures and scale TensorFlow worker clusters up and down.
Note to Implementors: In addition to these abstract methods, you must also
implement the task_type, task_index, and rpc_layer attributes. You may choose
to implement them either as properties with getters or setters or directly
set the attributes.
- task_type is the name of the server's current named job (e.g. 'worker',
'ps' in a distributed parameter server training job).
- task_index is the ordinal index of the server within the task type.
- rpc_layer is the protocol used by TensorFlow to communicate with other
TensorFlow servers in a distributed environment.
"""
@abc.abstractmethod
def cluster_spec(self):
"""Retrieve the current state of the cluster and returns a ClusterSpec.
Returns:
A ClusterSpec representing the state of the cluster at the moment this
function is called.
Implementors of this function must take care in ensuring that the
ClusterSpec returned is up-to-date at the time of calling this function.
This usually means retrieving the information from the underlying cluster
management system every time this function is invoked and reconstructing
a cluster_spec, rather than attempting to cache anything.
"""
raise NotImplementedError()
@abc.abstractmethod
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Retrieves the name or URL of the session master.
Args:
task_type: (Optional) The type of the TensorFlow task of the master.
task_index: (Optional) The index of the TensorFlow task of the master.
rpc_layer: (Optional) The RPC protocol for the given cluster.
Returns:
The name or URL of the session master.
Implementors of this function must take care in ensuring that the master
returned is up-to-date at the time of calling this function. This usually
means retrieving the master every time this function is invoked.
"""
raise NotImplementedError()
@abc.abstractmethod
def num_accelerators_per_worker(self, session_config=None):
"""Returns the number of accelerator cores per worker.
This returns the number of accelerator cores (such as GPUs and TPUs)
available per worker. If a worker only has CPU cores available, then this
should return 0. This method will query the master for this information
if it is not otherwise known.
Args:
session_config: (Optional) Configuration for starting a new session to
query how many accelerator cores it has.
"""
raise NotImplementedError()
@abc.abstractproperty
def environment(self):
"""Returns the current environment which TensorFlow is running in."""
raise NotImplementedError()
class SimpleClusterResolver(ClusterResolver):
"""Simple implementation of ClusterResolver that accepts a ClusterSpec."""
def __init__(self, cluster_spec, master='', task_type=None, task_index=None,
environment='', num_accelerators_per_worker=0,
rpc_layer=None):
"""Creates a SimpleClusterResolver from a ClusterSpec."""
super(SimpleClusterResolver, self).__init__()
self._task_type = task_type
self._task_index = task_index
self._environment = environment
self._num_accelerators_per_worker = num_accelerators_per_worker
self._rpc_layer = rpc_layer
if not isinstance(cluster_spec, ClusterSpec):
raise TypeError('cluster_spec must be a ClusterSpec.')
self._cluster_spec = cluster_spec
if not isinstance(master, str):
raise TypeError('master must be a string.')
self._master = master
def cluster_spec(self):
"""Returns the ClusterSpec passed into the constructor."""
return self._cluster_spec
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Returns the master address to use when creating a session.
Args:
task_type: (Optional) The type of the TensorFlow task of the master.
task_index: (Optional) The index of the TensorFlow task of the master.
rpc_layer: (Optional) The RPC used by distributed TensorFlow.
Returns:
The name or URL of the session master.
If task_type and task_index are given, this will override the `master`
string passed into the initialization function.
"""
if task_type is not None and task_index is not None:
master = self.cluster_spec().task_address(task_type, task_index)
else:
master = self._master
return format_master_url(master, rpc_layer=rpc_layer or self._rpc_layer)
@property
def task_type(self):
return self._task_type
@property
def task_index(self):
return self._task_index
@task_type.setter
def task_type(self, task_type):
self._task_type = task_type
@task_index.setter
def task_index(self, task_index):
self._task_index = task_index
@property
def environment(self):
return self._environment
def num_accelerators_per_worker(self, session_config=None):
"""Returns the number of accelerator cores per worker.
Args:
session_config: Unused. The SimpleClusterResolver does not do automatic
detection of accelerators, so a TensorFlow session will never be
created, and thus a `session_config` is never necessary here, and will
be ignored.
"""
del session_config
return self._num_accelerators_per_worker
@property
def rpc_layer(self):
return self._rpc_layer
@rpc_layer.setter
def rpc_layer(self, rpc_layer):
self._rpc_layer = rpc_layer
class UnionClusterResolver(ClusterResolver):
"""Performs a union on underlying ClusterResolvers.
This class performs a union given two or more existing ClusterResolvers. It
merges the underlying ClusterResolvers, and returns one unified ClusterSpec
when cluster_spec is called. The details of the merge function are
documented in the cluster_spec function.
For additional Cluster Resolver properties such as task type, task index,
rpc layer, environment, etc..., we will return the value from the first
ClusterResolver in the union.
"""
def __init__(self, *args, **kwargs):
"""Initializes a UnionClusterResolver with other ClusterResolvers.
Args:
*args: `ClusterResolver` objects to be unionized.
**kwargs:
rpc_layer - (Optional) Override value for the RPC layer used by
TensorFlow.
task_type - (Optional) Override value for the current task type.
task_index - (Optional) Override value for the current task index.
Raises:
TypeError: If any argument is not an instance of `ClusterResolver`.
ValueError: If there are no arguments passed.
"""
super(UnionClusterResolver, self).__init__()
self._rpc_layer = kwargs.pop('rpc_layer', None)
self._task_type = kwargs.pop('task_type', None)
self._task_index = kwargs.pop('task_index', None)
if kwargs:
raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs))
if not args:
raise ValueError('At least one ClusterResolver is required.')
for cluster_resolver in args:
if not isinstance(cluster_resolver, ClusterResolver):
raise TypeError('All arguments must be an instance of '
'`ClusterResolver`.')
self._cluster_resolvers = args
def cluster_spec(self):
"""Returns a union of all the ClusterSpecs from the ClusterResolvers.
Returns:
A ClusterSpec containing host information merged from all the underlying
ClusterResolvers.
Raises:
KeyError: If there are conflicting keys detected when merging two or
more dictionaries, this exception is raised.
Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the
same job name, we will merge the list/dict of workers.
If *all* underlying ClusterSpecs expose the set of workers as lists, we will
concatenate the lists of workers, starting with the list of workers from
the first ClusterResolver passed into the constructor.
If *any* of the ClusterSpecs expose the set of workers as a dict, we will
treat all the sets of workers as dicts (even if they are returned as lists)
and will only merge them into a dict if there is no conflicting keys. If
there is a conflicting key, we will raise a `KeyError`.
"""
merged_cluster = {}
# We figure out whether it is all lists for a particular job, or whether
# there are dicts inside.
for cluster_resolver in self._cluster_resolvers:
cluster_spec = cluster_resolver.cluster_spec()
cluster_dict = cluster_spec.as_dict()
for job_name, tasks in cluster_dict.items():
if job_name in merged_cluster:
# If we see a dict, then we write a dict out regardless.
if isinstance(tasks, dict):
merged_cluster[job_name] = {}
else:
# We take whichever type is present.
if isinstance(tasks, list):
merged_cluster[job_name] = []
else:
merged_cluster[job_name] = {}
# We then do the merge as appropriate in merged_cluster[job].
for cluster_resolver in self._cluster_resolvers:
cluster_spec = cluster_resolver.cluster_spec()
cluster_dict = cluster_spec.as_dict()
for job_name, tasks in cluster_dict.items():
if isinstance(merged_cluster[job_name], list):
# We all have lists, we can just concatenate and be done.
merged_cluster[job_name].extend(tasks)
else:
if isinstance(tasks, list):
# We convert to a dictionary if the type is a list.
task_dict = dict(zip(range(0, len(tasks)), tasks))
else:
# We can simply make a copy (for update) and be done.
task_dict = tasks.copy()
# We detect if there are duplicates, and raise an error if so.
task_keys = set(task_dict)
merged_keys = set(merged_cluster[job_name].keys())
intersected_keys = task_keys.intersection(merged_keys)
if intersected_keys:
raise KeyError('Duplicate keys detected when merging two '
'ClusterSpecs: %s' % repr(intersected_keys))
# We do the merge after all the processing.
merged_cluster[job_name].update(task_dict)
return ClusterSpec(merged_cluster)
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Returns the master address to use when creating a session.
This usually returns the master from the first ClusterResolver passed in,
but you can override this by specifying the task_type and task_index.
Args:
task_type: (Optional) The type of the TensorFlow task of the master.
task_index: (Optional) The index of the TensorFlow task of the master.
rpc_layer: (Optional) The RPC protocol for the given cluster.
Returns:
The name or URL of the session master.
"""
if task_type is not None and task_index is not None:
master = self.cluster_spec().task_address(task_type, task_index)
return format_master_url(master, rpc_layer or self._rpc_layer)
return self._cluster_resolvers[0].master(rpc_layer=rpc_layer)
@property
def task_type(self):
return self._task_type or self._cluster_resolvers[0].task_type
@property
def task_index(self):
return self._task_index or self._cluster_resolvers[0].task_index
@task_type.setter
def task_type(self, task_type):
self._task_type = task_type
@task_index.setter
def task_index(self, task_index):
self._task_index = task_index
@property
def environment(self):
return self._cluster_resolvers[0].environment
def num_accelerators_per_worker(self, session_config=None):
return self._cluster_resolvers[0].num_accelerators_per_worker(
session_config)
@property
def rpc_layer(self):
return self._rpc_layer or self._cluster_resolvers[0].rpc_layer
@rpc_layer.setter
def rpc_layer(self, rpc_layer):
self._rpc_layer = rpc_layer
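A hedged usage sketch (not part of this commit) of the two concrete resolvers defined above; the addresses are placeholders. Because both wrapped ClusterSpecs expose their workers as lists, the union concatenates them in constructor order:

from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import UnionClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec

simple_a = SimpleClusterResolver(ClusterSpec({'worker': ['10.0.0.1:8470']}),
                                 master='10.0.0.1:8470', rpc_layer='grpc')
simple_b = SimpleClusterResolver(ClusterSpec({'worker': ['10.0.0.2:8470']}))
union = UnionClusterResolver(simple_a, simple_b)
print(simple_a.master())               # 'grpc://10.0.0.1:8470'
print(union.cluster_spec().as_dict())  # {'worker': ['10.0.0.1:8470', '10.0.0.2:8470']}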

View File

@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import UnionClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib

View File

@ -0,0 +1,206 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for GCE Instance Groups."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
_GOOGLE_API_CLIENT_INSTALLED = True
try:
from googleapiclient import discovery # pylint: disable=g-import-not-at-top
from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top
except ImportError:
_GOOGLE_API_CLIENT_INSTALLED = False
def _format_master_url(master, rpc_layer=None):
return '%s://%s' % (rpc_layer, master) if rpc_layer else master
class GceClusterResolver(ClusterResolver):
"""Cluster Resolver for Google Compute Engine.
This is an implementation of cluster resolvers for the Google Compute Engine
instance group platform. By specifying a project, zone, and instance group,
this will retrieve the IP address of all the instances within the instance
group and return a Cluster Resolver object suitable for use for distributed
TensorFlow.
"""
def __init__(self,
project,
zone,
instance_group,
port,
task_type='worker',
task_index=0,
rpc_layer='grpc',
num_accelerators_per_worker=0,
credentials='default',
service=None):
"""Creates a new GceClusterResolver object.
This takes in a few parameters and creates a GceClusterResolver object. It
will then use these parameters to query the GCE API for the IP addresses of
each instance in the instance group.
Args:
project: Name of the GCE project.
zone: Zone of the GCE instance group.
instance_group: Name of the GCE instance group.
port: Port of the listening TensorFlow server (default: 8470)
task_type: Name of the TensorFlow job this GCE instance group of VM
instances belong to.
task_index: The task index for this particular VM, within the GCE
instance group. In particular, every single instance should be assigned
a unique ordinal index within an instance group manually so that they
can be distinguished from each other.
rpc_layer: The RPC layer TensorFlow should use to communicate across
instances.
num_accelerators_per_worker: Number of accelerators (GPUs) present per
instance.
credentials: GCE Credentials. If nothing is specified, this defaults to
GoogleCredentials.get_application_default().
service: The GCE API object returned by the googleapiclient.discovery
function. (Default: discovery.build('compute', 'v1')). If you specify a
custom service object, then the credentials parameter will be ignored.
Raises:
ImportError: If the googleapiclient is not installed.
"""
self._project = project
self._zone = zone
self._instance_group = instance_group
self._task_type = task_type
self._task_index = task_index
self._rpc_layer = rpc_layer
self._port = port
# Store the accelerator count so num_accelerators_per_worker() can return it.
self._num_accelerators_per_worker = num_accelerators_per_worker
self._credentials = credentials
if credentials == 'default':
if _GOOGLE_API_CLIENT_INSTALLED:
self._credentials = GoogleCredentials.get_application_default()
if service is None:
if not _GOOGLE_API_CLIENT_INSTALLED:
raise ImportError('googleapiclient must be installed before using the '
'GCE cluster resolver')
self._service = discovery.build(
'compute', 'v1',
credentials=self._credentials)
else:
self._service = service
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest instance group info.
This returns a ClusterSpec object for use based on information from the
specified instance group. We will retrieve the information from the GCE APIs
every time this method is called.
Returns:
A ClusterSpec containing host information retrieved from GCE.
"""
request_body = {'instanceState': 'RUNNING'}
request = self._service.instanceGroups().listInstances(
project=self._project,
zone=self._zone,
instanceGroups=self._instance_group,
body=request_body,
orderBy='name')
worker_list = []
while request is not None:
response = request.execute()
items = response['items']
for instance in items:
instance_name = instance['instance'].split('/')[-1]
instance_request = self._service.instances().get(
project=self._project,
zone=self._zone,
instance=instance_name)
if instance_request is not None:
instance_details = instance_request.execute()
ip_address = instance_details['networkInterfaces'][0]['networkIP']
instance_url = '%s:%s' % (ip_address, self._port)
worker_list.append(instance_url)
request = self._service.instanceGroups().listInstances_next(
previous_request=request,
previous_response=response)
worker_list.sort()
return ClusterSpec({self._task_type: worker_list})
def master(self, task_type=None, task_index=None, rpc_layer=None):
task_type = task_type if task_type is not None else self._task_type
task_index = task_index if task_index is not None else self._task_index
if task_type is not None and task_index is not None:
master = self.cluster_spec().task_address(task_type, task_index)
if rpc_layer or self._rpc_layer:
return '%s://%s' % (rpc_layer or self._rpc_layer, master)
else:
return master
return ''
@property
def task_type(self):
return self._task_type
@property
def task_index(self):
return self._task_index
@task_type.setter
def task_type(self, task_type):
raise RuntimeError(
'You cannot reset the task_type of the GceClusterResolver after it has '
'been created.')
@task_index.setter
def task_index(self, task_index):
self._task_index = task_index
@property
def environment(self):
"""Returns the current environment which TensorFlow is running in.
For users in the GCE environment, the environment property is always an
empty string, and Google users will not use this ClusterResolver for running
on internal systems.
"""
return ''
@property
def rpc_layer(self):
return self._rpc_layer
@rpc_layer.setter
def rpc_layer(self, rpc_layer):
self._rpc_layer = rpc_layer
def num_accelerators_per_worker(self, session_config=None):
del session_config # Unused, since this is set manually in __init__.
return self._num_accelerators_per_worker
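A hedged usage sketch (not part of this commit); the project, zone, and instance group names below are assumptions, and the googleapiclient package must be installed so the GCE API can be queried:

from tensorflow.python.distribute.cluster_resolver import GceClusterResolver

# Assumed: an instance group 'tf-workers' of running VMs serving on port 8470.
gce = GceClusterResolver(project='my-gcp-project', zone='us-central1-b',
                         instance_group='tf-workers', port=8470)
spec = gce.cluster_spec()  # One 'worker' entry per RUNNING instance, sorted.
master = gce.master()      # 'grpc://<address of worker 0>'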

View File

@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
from tensorflow.python.distribute.cluster_resolver import GceClusterResolver
from tensorflow.python.distribute.cluster_resolver import UnionClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib

View File

@ -0,0 +1,173 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Kubernetes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import device_lib
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
from tensorflow.python.training import server_lib
_KUBERNETES_API_CLIENT_INSTALLED = True
try:
from kubernetes import client as k8sclient # pylint: disable=g-import-not-at-top
from kubernetes import config as k8sconfig # pylint: disable=g-import-not-at-top
except ImportError:
_KUBERNETES_API_CLIENT_INSTALLED = False
class KubernetesClusterResolver(ClusterResolver):
"""Cluster Resolver for Kubernetes.
This is an implementation of cluster resolvers for Kubernetes. When given the
Kubernetes namespace and label selector for pods, we will retrieve the
pod IP addresses of all running pods matching the selector, and return a
ClusterSpec based on that information.
"""
def __init__(self,
job_to_label_mapping=None,
tf_server_port=8470,
rpc_layer='grpc',
override_client=None):
"""Initializes a new KubernetesClusterResolver.
This initializes a new Kubernetes Cluster Resolver. The Cluster Resolver
will attempt to talk to the Kubernetes master to retrieve all the instances
of pods matching a label selector.
Args:
job_to_label_mapping: A mapping of TensorFlow jobs to label selectors.
This allows users to specify many TensorFlow jobs in one Cluster
Resolver, and each job can have pods belonging to different label
selectors. For example, a sample mapping might be
```
{'worker': ['job-name=worker-cluster-a', 'job-name=worker-cluster-b'],
'ps': ['job-name=ps-1', 'job-name=ps-2']}
```
tf_server_port: The port the TensorFlow server is listening on.
rpc_layer: (Optional) The RPC layer TensorFlow should use to communicate
between tasks in Kubernetes. Defaults to 'grpc'.
override_client: The Kubernetes client (usually automatically retrieved
using `from kubernetes import client as k8sclient`). If you pass this
in, you are responsible for setting Kubernetes credentials manually.
Raises:
ImportError: If the Kubernetes Python client is not installed and no
`override_client` is passed in.
RuntimeError: If autoresolve_task is not a boolean or a callable.
"""
if _KUBERNETES_API_CLIENT_INSTALLED:
k8sconfig.load_kube_config()
if not job_to_label_mapping:
job_to_label_mapping = {'worker': ['job-name=tensorflow']}
if not override_client and not _KUBERNETES_API_CLIENT_INSTALLED:
raise ImportError('The Kubernetes Python client must be installed before '
'using the Kubernetes Cluster Resolver. To install the '
'Kubernetes Python client, run `pip install '
'kubernetes` on your command line.')
self._job_to_label_mapping = job_to_label_mapping
self._tf_server_port = tf_server_port
self._override_client = override_client
self.task_type = None
self.task_index = None
self.rpc_layer = rpc_layer
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Returns the master address to use when creating a session.
You must have set the task_type and task_index object properties before
calling this function, or pass in the `task_type` and `task_index`
parameters when using this function. If you do both, the function parameters
will override the object properties.
Args:
task_type: (Optional) The type of the TensorFlow task of the master.
task_index: (Optional) The index of the TensorFlow task of the master.
rpc_layer: (Optional) The RPC protocol for the given cluster.
Returns:
The name or URL of the session master.
"""
if task_type is not None and task_index is not None:
return format_master_url(
self.cluster_spec().task_address(task_type, task_index),
rpc_layer or self.rpc_layer)
if self.task_type is not None and self.task_index is not None:
return format_master_url(
self.cluster_spec().task_address(self.task_type, self.task_index),
rpc_layer or self.rpc_layer)
return ''
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest info from Kubernetes.
We retrieve the information from the Kubernetes master every time this
method is called.
Returns:
A ClusterSpec containing host information returned from Kubernetes.
Raises:
RuntimeError: If any of the pods returned by the master is not in the
`Running` phase.
"""
if not self._override_client:
k8sconfig.load_kube_config()
client = self._override_client or k8sclient.CoreV1Api()
cluster_map = {}
for tf_job in self._job_to_label_mapping:
all_pods = []
for selector in self._job_to_label_mapping[tf_job]:
ret = client.list_pod_for_all_namespaces(label_selector=selector)
selected_pods = []
# Sort the list by the name to make sure it doesn't change call to call.
for pod in sorted(ret.items, key=lambda x: x.metadata.name):
if pod.status.phase == 'Running':
selected_pods.append(
'%s:%s' % (pod.status.host_ip, self._tf_server_port))
else:
raise RuntimeError('Pod "%s" is not running; phase: "%s"' %
(pod.metadata.name, pod.status.phase))
all_pods.extend(selected_pods)
cluster_map[tf_job] = all_pods
return server_lib.ClusterSpec(cluster_map)
@property
def environment(self):
"""Returns the current environment which TensorFlow is running in.
For users in the Cloud environment, the environment property is always an
empty string, and Google users will not use this ClusterResolver for running
on internal systems.
"""
return ''
def num_accelerators_per_worker(self, session_config=None):
local_devices = device_lib.list_local_devices(session_config)
return len([d for d in local_devices if d.device_type == 'GPU'])
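A hedged usage sketch (not part of this commit); it assumes the `kubernetes` Python client is installed, kube credentials are available via `load_kube_config`, and the worker pods carry the default `job-name=tensorflow` label:

from tensorflow.python.distribute.cluster_resolver import KubernetesClusterResolver

k8s = KubernetesClusterResolver(
    job_to_label_mapping={'worker': ['job-name=tensorflow']},
    tf_server_port=8470)
spec = k8s.cluster_spec()  # One 'worker' address per Running pod matching the selector.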

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.cluster_resolver.python.training import KubernetesClusterResolver
from tensorflow.python.distribute.cluster_resolver import KubernetesClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib

View File

@ -0,0 +1,226 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Slurm workload manager."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import subprocess
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
class SlurmClusterResolver(ClusterResolver):
"""Cluster Resolver for system with Slurm workload manager.
This is an implementation of cluster resolvers for Slurm clusters. This allows
the specification of jobs and task counts, number of tasks per node, number of
GPUs on each node and number of GPUs for each task, It retrieves system
attributes by Slurm environment variables, resolves allocated computing node
names, construct a cluster and return a Cluster Resolver object which an be
use for distributed TensorFlow.
"""
def _resolve_hostnames(self):
"""Resolve host names of nodes allocated in current jobs.
Returns:
A list of node names as strings.
"""
hostlist = (subprocess.check_output(['scontrol', 'show', 'hostname']).
decode('utf-8').strip().split('\n'))
return hostlist
def __init__(self,
jobs,
port_base=8888,
gpus_per_node=1,
gpus_per_task=1,
tasks_per_node=None,
auto_set_gpu=True,
rpc_layer='grpc'):
"""Creates a new SlurmClusterResolver object.
This takes in parameters and creates a SlurmClusterResolver object. It uses
those parameters to determine on which nodes the processes will reside and
resolves their hostnames. Using the number of GPUs on each node and the
number of GPUs per task, it offsets the port number for each process and
allocates GPUs to tasks by setting environment variables. The resolver
currently supports homogeneous tasks and the default Slurm process allocation.
Args:
jobs: Dictionary with job names as key and number of tasks in the job as
value
port_base: The first port number to start with for processes on a node.
gpus_per_node: Number of GPUs available on each node.
gpus_per_task: Number of GPUs to be used for each task.
tasks_per_node: Number of tasks to run on each node. If not set, defaults
to Slurm's output environment variable SLURM_NTASKS_PER_NODE.
auto_set_gpu: Set the visible CUDA devices automatically while resolving
the cluster by setting CUDA_VISIBLE_DEVICES environment variable.
Defaults to True.
rpc_layer: (Optional) The protocol TensorFlow uses to communicate between
nodes. Defaults to 'grpc'.
Returns:
A ClusterResolver object which can be used with distributed TensorFlow.
Raises:
RuntimeError: If more GPUs per node are requested than are available, or
if more tasks are requested than the number of assigned tasks.
"""
# check if launched by mpirun
if 'OMPI_COMM_WORLD_RANK' in os.environ:
self._rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
num_tasks = int(os.environ['OMPI_COMM_WORLD_SIZE'])
else:
self._rank = int(os.environ['SLURM_PROCID'])
num_tasks = int(os.environ['SLURM_NTASKS'])
self._jobs = collections.OrderedDict(sorted(jobs.items()))
self._port_base = port_base
# user specification overrides SLURM specification
if tasks_per_node is not None:
self._tasks_per_node = tasks_per_node
elif tasks_per_node is None and 'SLURM_NTASKS_PER_NODE' in os.environ:
self._tasks_per_node = int(os.environ['SLURM_NTASKS_PER_NODE'])
else:
raise RuntimeError('Neither `tasks_per_node` nor '
'SLURM_NTASKS_PER_NODE is set.')
self._gpus_per_node = gpus_per_node
self._gpus_per_task = gpus_per_task
self._auto_set_gpu = auto_set_gpu
self.task_type = None
self.task_index = None
self.rpc_layer = rpc_layer
self._gpu_allocation = []
self._cluster_allocation = {}
if self._tasks_per_node * self._gpus_per_task > self._gpus_per_node:
raise RuntimeError('Requested more GPUs per node than are available.')
if sum(self._jobs.values()) != num_tasks:
raise RuntimeError('Requested more tasks than were assigned.')
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest instance group info.
This returns a ClusterSpec object for use based on information from the
specified initialization parameters and Slurm environment variables. The
cluster specification is resolved each time this function is called. The
resolver extracts the hostnames of the nodes via scontrol and packs tasks
onto the nodes in that order until each node holds its specified number of
tasks. GPUs on the nodes are allocated to tasks as specified by setting the
CUDA_VISIBLE_DEVICES environment variable.
Returns:
A ClusterSpec containing host information retrieved from Slurm's
environment variables.
"""
hostlist = self._resolve_hostnames()
task_list = []
self._gpu_allocation = []
self._cluster_allocation = {}
for host in hostlist:
for port_offset, gpu_offset in zip(
range(self._tasks_per_node),
range(0, self._gpus_per_node, self._gpus_per_task)):
host_addr = '%s:%d' % (host, self._port_base + port_offset)
task_list.append(host_addr)
gpu_id_list = []
for gpu_id in range(gpu_offset, gpu_offset + self._gpus_per_task):
gpu_id_list.append(str(gpu_id))
self._gpu_allocation.append(','.join(gpu_id_list))
cluster_rank_offset_start = 0
cluster_rank_offset_end = 0
for task_type, num_tasks in self._jobs.items():
cluster_rank_offset_end = cluster_rank_offset_start + num_tasks
self._cluster_allocation[task_type] = (
task_list[cluster_rank_offset_start:cluster_rank_offset_end])
if cluster_rank_offset_start <= self._rank < cluster_rank_offset_end:
self.task_type = task_type
self.task_index = self._rank - cluster_rank_offset_start
cluster_rank_offset_start = cluster_rank_offset_end
if self._auto_set_gpu is True:
os.environ['CUDA_VISIBLE_DEVICES'] = self._gpu_allocation[self._rank]
return ClusterSpec(self._cluster_allocation)
def get_task_info(self):
"""Returns job name and task_index for the process which calls this.
This returns the job name and task index for the process which calls this
function according to its rank and cluster specification. The job name and
task index are set after a cluster is constructed by cluster_spec; otherwise
they default to None.
Returns:
A string specifying the job name the process belongs to and an integer
specifying the task index of the process within that job.
"""
return self.task_type, self.task_index
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Returns the master string for connecting to a TensorFlow master.
Args:
task_type: (Optional) Overrides the default auto-selected task type.
task_index: (Optional) Overrides the default auto-selected task index.
rpc_layer: (Optional) Overrides the default RPC protocol TensorFlow uses
to communicate across nodes.
Returns:
A connection string for connecting to a TensorFlow master.
"""
task_type = task_type if task_type is not None else self.task_type
task_index = task_index if task_index is not None else self.task_index
rpc_layer = rpc_layer or self.rpc_layer
master = self.cluster_spec().task_address(task_type, task_index)
return '%s://%s' % (rpc_layer, master) if rpc_layer else master
@property
def environment(self):
"""Returns the current environment which TensorFlow is running in.
For users in the Slurm environment, the environment property is always an
empty string, and Google users will not use this ClusterResolver for running
on internal systems.
"""
return ''
def num_accelerators_per_worker(self, session_config=None):
del session_config # Unused, since this is set in __init__ manually.
return self._gpus_per_node
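A hedged usage sketch (not part of this commit); it assumes the script runs inside a Slurm allocation (so SLURM_PROCID and SLURM_NTASKS are set) with five tasks, two tasks per node, two GPUs per node, and one GPU per task:

from tensorflow.python.distribute.cluster_resolver import SlurmClusterResolver

slurm = SlurmClusterResolver(jobs={'ps': 1, 'worker': 4},
                             port_base=8888,
                             gpus_per_node=2,
                             gpus_per_task=1,
                             tasks_per_node=2)
spec = slurm.cluster_spec()                    # Also assigns task_type/task_index and
task_type, task_index = slurm.get_task_info()  # sets CUDA_VISIBLE_DEVICES for this rank.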

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import os
from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver
from tensorflow.python.distribute.cluster_resolver import SlurmClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib

View File

@ -0,0 +1,171 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
_TF_CONFIG_ENV = 'TF_CONFIG'
_SESSION_MASTER_KEY = 'session_master'
_RPC_LAYER_KEY = 'rpc_layer'
_TASK_KEY = 'task'
def format_master_url(master, rpc_layer=None):
if rpc_layer:
return '%s://%s' % (rpc_layer, master)
else:
return master
def _load_tf_config():
return json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
def _get_value_in_tfconfig(key, default=None):
tf_config = _load_tf_config()
return tf_config[key] if key in tf_config else default
class TFConfigClusterResolver(ClusterResolver):
"""Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar."""
def __init__(self,
task_type=None,
task_index=None,
rpc_layer=None,
environment=None,
num_accelerators_per_worker=0):
"""Creates a new TFConfigClusterResolver.
Args:
task_type: (String, optional) Overrides the task type specified in the
TF_CONFIG environment variable.
task_index: (Integer, optional) Overrides the task index specified in the
TF_CONFIG environment variable.
rpc_layer: (String, optional) Overrides the rpc layer TensorFlow uses.
environment: (String, optional) Overrides the environment TensorFlow
operates in.
num_accelerators_per_worker: (Integer, optional) Specifies the number of
accelerators (e.g. GPUs, TPUs, others) that each node has.
"""
self._task_type = task_type
self._task_index = task_index
self._rpc_layer = rpc_layer
self._environment = environment
self._num_accelerators_per_worker = num_accelerators_per_worker
@property
def task_type(self):
if self._task_type is None:
task_info = _get_value_in_tfconfig(_TASK_KEY, {})
return task_info['type'] if 'type' in task_info else None
else:
return self._task_type
@property
def task_index(self):
if self._task_index is None:
task_info = _get_value_in_tfconfig(_TASK_KEY, {})
return task_info['index'] if 'index' in task_info else None
else:
return self._task_index
@task_type.setter
def task_type(self, task_type):
self._task_type = task_type
@task_index.setter
def task_index(self, task_index):
self._task_index = task_index
@property
def environment(self):
return self._environment
@property
def rpc_layer(self):
if self._rpc_layer is None:
return _get_value_in_tfconfig(_RPC_LAYER_KEY)
else:
return self._rpc_layer
@rpc_layer.setter
def rpc_layer(self, rpc_layer):
self._rpc_layer = rpc_layer
def num_accelerators_per_worker(self, session_config=None):
# TODO(frankchn): Connect to server (w/ session_config) in the future.
del session_config # Unused, we do not connect to another server here.
return self._num_accelerators_per_worker
def cluster_spec(self):
"""Returns a ClusterSpec based on the TF_CONFIG environment variable.
Returns:
A ClusterSpec with information from the TF_CONFIG environment variable.
"""
tf_config = _load_tf_config()
if 'cluster' not in tf_config:
return ClusterSpec({})
return ClusterSpec(tf_config['cluster'])
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Returns the master address to use when creating a TensorFlow session.
Args:
task_type: (String, optional) Overrides and sets the task_type of the
master.
task_index: (Integer, optional) Overrides and sets the task id of the
master.
rpc_layer: (String, optional) Overrides and sets the protocol over which
TensorFlow nodes communicate with each other.
Returns:
The address of the master.
Raises:
RuntimeError: If the task_type or task_index is not specified and the
`TF_CONFIG` environment variable does not contain a task section.
"""
# If `session_master` is set, just use that.
session_master = _get_value_in_tfconfig(_SESSION_MASTER_KEY)
if session_master is not None:
return session_master
# Return an empty string if we are the only job in the ClusterSpec.
cluster_spec = self.cluster_spec()
if (not cluster_spec.jobs or
(len(cluster_spec.jobs) == 1 and
len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)):
return ''
# We try to auto-detect the task type and index, but use the user-supplied
# values where available.
task_type = task_type if task_type is not None else self.task_type
task_index = task_index if task_index is not None else self.task_index
return format_master_url(cluster_spec.task_address(task_type, task_index),
self.rpc_layer)
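# --- Illustrative usage sketch (editorial addition, not part of the original
# module). It assumes a hypothetical two-worker cluster; the addresses below
# are made up.
def _example_tfconfig_usage():
  os.environ[_TF_CONFIG_ENV] = json.dumps({
      'cluster': {'worker': ['10.1.1.1:2222', '10.1.1.2:2222']},
      'task': {'type': 'worker', 'index': 0},
      'rpc_layer': 'grpc',
  })
  resolver = TFConfigClusterResolver()
  # cluster_spec() re-reads TF_CONFIG on every call.
  print(resolver.cluster_spec().as_dict())  # {'worker': ['10.1.1.1:2222', ...]}
  print(resolver.master())                  # 'grpc://10.1.1.1:2222'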

View File

@@ -20,7 +20,7 @@ from __future__ import print_function
import os
from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib

View File

@@ -0,0 +1,423 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Cloud TPUs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from six.moves.urllib.request import Request
from six.moves.urllib.request import urlopen
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
_GOOGLE_API_CLIENT_INSTALLED = True
try:
from googleapiclient import discovery # pylint: disable=g-import-not-at-top
from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top
except ImportError:
_GOOGLE_API_CLIENT_INSTALLED = False
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
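# Editorial note (not part of the original module): judging from the handling
# in __init__ and cluster_spec below, KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS is
# expected to hold a comma-separated list of gRPC endpoints, e.g. (made-up
# addresses)
#   'grpc://10.2.3.4:8470,grpc://10.2.3.5:8470'
# while TPU_NAME holds a bare Cloud TPU node name to be resolved via the API.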
class TPUClusterResolver(ClusterResolver):
"""Cluster Resolver for Google Cloud TPUs.
This is an implementation of cluster resolvers for the Google Cloud TPU
service. As Cloud TPUs are in alpha, you will need to specify an API definition
file for this to consume, in addition to a list of Cloud TPUs in your Google
Cloud Platform project.
"""
def _tpuService(self):
"""Creates a new Cloud TPU API object.
This works around an issue where the underlying HTTP connection sometimes
times out when the script has been running for too long. Other methods in
this object call this method to get a new API object whenever they need
to communicate with the Cloud API.
Returns:
A Google Cloud TPU API object.
"""
if self._service:
return self._service
credentials = self._credentials
if credentials is None or credentials == 'default':
credentials = GoogleCredentials.get_application_default()
if self._discovery_url:
return discovery.build(
'tpu', 'v1alpha1',
credentials=credentials,
discoveryServiceUrl=self._discovery_url)
else:
return discovery.build(
'tpu', 'v1alpha1',
credentials=credentials)
def _requestComputeMetadata(self, path):
req = Request('http://metadata/computeMetadata/v1/%s' % path,
headers={'Metadata-Flavor': 'Google'})
resp = urlopen(req)
return compat.as_bytes(resp.read())
def _shouldResolve(self):
if isinstance(self._should_resolve_override, bool):
return self._should_resolve_override
if (self._tpu == compat.as_bytes('') or
self._tpu == compat.as_bytes('local') or
self._tpu.startswith(compat.as_bytes('/bns')) or
self._tpu.startswith(compat.as_bytes('localhost:')) or
self._tpu.startswith(compat.as_bytes('grpc://'))):
return False
return True
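# Editorial note (not part of the original module): per the checks above, TPU
# values of '', 'local', '/bns...', 'localhost:...' and 'grpc://...' are
# treated as already resolved, so no Cloud TPU API lookup is attempted for
# them.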
@staticmethod
def _inGke():
"""When running in GKE, the environment variable will be set."""
return _GKE_ENV_VARIABLE in os.environ
@staticmethod
def _gkeEndpoints():
return os.environ[_GKE_ENV_VARIABLE]
@staticmethod
def _envVarFallback():
if _DEFAULT_ENV_VARIABLE in os.environ:
return os.environ[_DEFAULT_ENV_VARIABLE]
return None
@staticmethod
def _environmentDiscoveryUrl():
return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
def __init__(self,
tpu=None,
zone=None,
project=None,
job_name='worker',
coordinator_name=None,
coordinator_address=None,
credentials='default',
service=None,
discovery_url=None):
"""Creates a new TPUClusterResolver object.
The ClusterResolver will then use the parameters to query the Cloud TPU APIs
for the IP addresses and ports of each Cloud TPU listed.
Args:
tpu: Either a string, or a list of strings corresponding to the TPUs to
use. If the single string is the empty string, the string 'local', or a
string that begins with 'grpc://' or '/bns', then it is assumed to not
correspond with a Cloud TPU and will instead be passed as the session
master and no ClusterSpec propagation will be done.
zone: Zone where the TPUs are located. If omitted or empty, we will assume
that the zone of the TPU is the same as the zone of the GCE VM, which we
will try to discover from the GCE metadata service.
project: Name of the GCP project containing Cloud TPUs. If omitted or
empty, we will try to discover the project name of the GCE VM from the
GCE metadata service.
job_name: Name of the TensorFlow job the TPUs belong to.
coordinator_name: The name to use for the coordinator. Set to None if the
coordinator should not be included in the computed ClusterSpec.
coordinator_address: The address of the coordinator (typically an ip:port
pair). If set to None and coordinator_name is provided, a TF server will
be started. If coordinator_name is None, a TF server will not be started
even if coordinator_address is None.
credentials: GCE Credentials. If None, then we use default credentials
from oauth2client.
service: The GCE API object returned by the googleapiclient.discovery
function. If you specify a custom service object, then the credentials
parameter will be ignored.
discovery_url: A URL template that points to the location of
the discovery service. It should have two parameters {api} and
{apiVersion} that when filled in produce an absolute URL to the
discovery document for that service. The environment variable
'TPU_API_DISCOVERY_URL' will override this.
Raises:
ImportError: If the googleapiclient is not installed.
ValueError: If no TPUs are specified.
"""
if isinstance(tpu, list):
if not tpu:
raise ValueError('At least one TPU must be specified.')
if len(tpu) != 1:
raise NotImplementedError(
'Using multiple TPUs in a single session is not yet implemented')
tpu = tpu[0]
in_gke = self._inGke()
# When using GKE with Cloud TPUs, the env variable will be set.
if tpu is None:
if in_gke:
tpu = self._gkeEndpoints()
else:
tpu = self._envVarFallback()
if tpu is None:
raise ValueError('Please provide a TPU Name to connect to.')
self._tpu = compat.as_bytes(tpu) # Normalized to bytes; may be reassigned a stripped string below.
# By default the task_type is 'worker' and the task_index is 0 (which is the
# first worker in the task).
self.task_type = job_name
self.task_index = 0
if tpu.startswith('grpc://'):
# Cloud environment, where we are using GRPC to communicate to TPUs.
self._environment = ''
elif tpu == 'local' or not tpu:
# Google environment, where the TPU is attached to the host.
self._environment = 'google'
elif tpu.startswith('/bns'):
# Google environment, where we reach the TPU through BNS.
self._environment = 'google'
# If TPU is in the Google environment or exists locally, we don't use any
# RPC layer.
if tpu.startswith('/bns') or tpu == 'local' or not tpu:
self.rpc_layer = None
else:
self.rpc_layer = 'grpc'
# Setting this overrides the return value of self._shouldResolve()
self._should_resolve_override = None
# We strip out the protocol if it is included, and override the
# shouldResolve function to never resolve. We are adding the protocol back
# in later in self.master().
if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'):
tpu = tpu[len(self.rpc_layer + '://'):]
self._tpu = tpu
self._should_resolve_override = False
# Whether we should actually attempt to contact Cloud APIs
should_resolve = self._shouldResolve()
# We error out if we are in a non-Cloud environment which cannot talk to the
# Cloud APIs using the standard class and a special object is not passed in.
self._service = service
if (self._service is None and should_resolve and
not _GOOGLE_API_CLIENT_INSTALLED):
raise ImportError('googleapiclient and oauth2client must be installed '
'before using the TPU cluster resolver. Execute: '
'`pip install --upgrade google-api-python-client` '
'and `pip install --upgrade oauth2client` to '
'install with pip.')
# We save user-passed credentials, unless the user didn't pass in anything.
self._credentials = credentials
if (credentials == 'default' and should_resolve and
_GOOGLE_API_CLIENT_INSTALLED):
self._credentials = None
# Automatically detect project and zone if unspecified.
if not project and should_resolve:
project = compat.as_str(
self._requestComputeMetadata('project/project-id'))
if not zone and should_resolve:
zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
zone = zone_path.split('/')[-1]
self._project = project
self._zone = zone
self._discovery_url = self._environmentDiscoveryUrl() or discovery_url
self._coordinator_name = coordinator_name
if (coordinator_name and not coordinator_address and
(should_resolve or in_gke)):
self._start_local_server()
else:
self._coordinator_address = coordinator_address
def master(self, task_type=None, task_index=None, rpc_layer=None):
"""Get the Master string to be used for the session.
In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of the
first instance in the ClusterSpec returned by the cluster_spec function.
If a non-TPU name is used when constructing a TPUClusterResolver, that will
be returned instead (e.g. if the tpu argument's value when constructing
this TPUClusterResolver was 'grpc://10.240.1.2:8470',
'grpc://10.240.1.2:8470' will be returned).
Args:
task_type: (Optional, string) The type of the TensorFlow task of the
master.
task_index: (Optional, integer) The index of the TensorFlow task of the
master.
rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
communicate with TPUs.
Returns:
string, the connection string to use when creating a session.
Raises:
ValueError: If none of the TPUs specified exists.
"""
if self._shouldResolve():
# We are going to communicate with the Cloud TPU APIs to get a Cluster.
cluster_spec = self.cluster_spec()
if task_type is not None and task_index is not None:
# task_type and task_index are from the function parameters
master = cluster_spec.task_address(task_type, task_index)
elif self.task_type is not None and self.task_index is not None:
# task_type and task_index are from the object
master = cluster_spec.task_address(self.task_type, self.task_index)
else:
# by default we take the first item in the cluster with the right name
job_tasks = cluster_spec.job_tasks(self.task_type)
if not job_tasks:
raise ValueError('No TPUs with the specified names exist.')
master = job_tasks[0]
else:
if isinstance(self._tpu, (bytes, bytearray)):
master = self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0]
else:
master = self._tpu.split(_ENDPOINTS_SEPARATOR)[0]
return format_master_url(master, rpc_layer or self.rpc_layer)
def get_master(self):
return self.master()
def get_job_name(self):
if self._shouldResolve():
return self.task_type
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
We retrieve the information from the GCE APIs every time this method is
called.
Returns:
A ClusterSpec containing host information returned from Cloud TPUs, or
None if the TPU is addressed without gRPC (the legacy non-gRPC case), in
which case no ClusterSpec can be constructed.
Raises:
RuntimeError: If the provided TPU is not healthy.
"""
############################################################################
# There are 5 potential cases this code must handle:
# 1. [Normal case.] We should resolve the TPU name to a set of tasks, and
# a. Create a ClusterSpec that includes the coordinator job
# b. Create a ClusterSpec without the coordinator job.
# 2. [GKE / No API Access.] We should not resolve the TPU name to a set of
# tasks and
# a. Create a ClusterSpec with the coordinator
# b. Create a ClusterSpec without the coordinator
# 3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
############################################################################
if self._shouldResolve():
# Case 1.
full_name = 'projects/%s/locations/%s/nodes/%s' % (
self._project, self._zone, compat.as_text(self._tpu))
service = self._tpuService()
request = service.projects().locations().nodes().get(name=full_name)
response = request.execute()
if 'state' in response and response['state'] != 'READY':
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
(compat.as_text(self._tpu), response['state']))
if 'health' in response and response['health'] != 'HEALTHY':
raise RuntimeError('TPU "%s" is unhealthy: "%s"' %
(compat.as_text(self._tpu), response['health']))
if 'networkEndpoints' in response:
worker_list = [
'%s:%s' % (endpoint['ipAddress'], endpoint['port'])
for endpoint in response['networkEndpoints']
]
else:
# Fall back to the deprecated response format
instance_url = '%s:%s' % (response['ipAddress'], response['port'])
worker_list = [instance_url]
cluster_spec = {self.task_type: worker_list}
else:
if self.rpc_layer is None:
# Case 3.
return None
# Case 2.
tpus = []
for tpu in self._tpu.split(_ENDPOINTS_SEPARATOR):
# We are working around the fact that the GKE environment variable that is
# supplied to us has the protocol string embedded in it, but we want
# to strip it out for the ClusterSpec.
if (self.rpc_layer is not None and
tpu.startswith(self.rpc_layer + '://')):
tpus.append(tpu[len(self.rpc_layer + '://'):])
else:
tpus.append(tpu)
cluster_spec = {self.task_type: tpus}
if self._coordinator_address:
# {1, 2}.a
cluster_spec[self._coordinator_name] = [self._coordinator_address]
return server_lib.ClusterSpec(cluster_spec)
def num_accelerators_per_worker(self, session_config=None):
"""Returns the number of TPU cores per worker.
This defaults to 8 for all current TPU configurations, and we do not need
to query any remote systems for this.
Args:
session_config: Unused. Not currently necessary to query anything as this
number is 8 for all TPU configurations.
"""
del session_config # Unused. Not necessary to query anything.
return 8
@property
def environment(self):
"""Returns the current environment which TensorFlow is running in."""
return self._environment
def _start_local_server(self):
address = self._requestComputeMetadata('instance/network-interfaces/0/ip')
self._server = server_lib.Server(
{
'local': ['0.0.0.0:0']
}, protocol='grpc', config=None, start=True)
# self._server.target is of the form: grpc://ipaddress:port
target = compat.as_bytes(self._server.target)
splits = target.split(compat.as_bytes(':'))
assert len(splits) == 3, self._server.target
assert splits[0] == compat.as_bytes('grpc'), self._server.target
self._coordinator_port = compat.as_text(splits[2])
# The metadata response is bytes; convert to text before formatting.
self._coordinator_address = '%s:%s' % (
compat.as_text(address), compat.as_text(self._coordinator_port))
def __deepcopy__(self, memo):
# TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
return self
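# --- Illustrative usage sketch (editorial addition, not part of the original
# module). Passing an explicit 'grpc://' endpoint keeps _shouldResolve() False,
# so no Cloud TPU API or metadata-server calls are made; the address is made
# up.
def _example_tpu_resolver_usage():
  resolver = TPUClusterResolver(tpu='grpc://10.240.1.2:8470')
  print(resolver.master())                  # 'grpc://10.240.1.2:8470'
  # One 'worker' job containing the single endpoint above.
  print(resolver.cluster_spec().as_dict())  # {'worker': ['10.240.1.2:8470']}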

View File

@@ -20,7 +20,7 @@ from __future__ import print_function
import os
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat