Moves ClusterResolvers into tensorflow.python.distribute in preparation for TensorFlow 2.0
PiperOrigin-RevId: 223401165
This commit is contained in:
parent
fd7b50ee62
commit
a26f3b0598
tensorflow
contrib
python
BUILD
distribute
BUILD
cluster_resolver
BUILDREADME.mdREADME.slurm__init__.pycluster_resolver.pycluster_resolver_test.pygce_cluster_resolver.pygce_cluster_resolver_test.pykubernetes_cluster_resolver.pykubernetes_cluster_resolver_test.pyslurm_cluster_resolver.pyslurm_cluster_resolver_test.pytfconfig_cluster_resolver.pytfconfig_cluster_resolver_test.pytpu_cluster_resolver.pytpu_cluster_resolver_test.py
@ -21,85 +21,18 @@ py_library(
|
||||
|
||||
py_library(
|
||||
name = "cluster_resolver_py",
|
||||
srcs = [
|
||||
srcs = glob([
|
||||
"__init__.py",
|
||||
"python/training/__init__.py",
|
||||
],
|
||||
"python/training/*.py",
|
||||
]),
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
":gce_cluster_resolver_py",
|
||||
":kubernetes_cluster_resolver_py",
|
||||
":slurm_cluster_resolver_py",
|
||||
":tfconfig_cluster_resolver_py",
|
||||
":tpu_cluster_resolver_py",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "base_cluster_resolver_py",
|
||||
srcs = ["python/training/cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gce_cluster_resolver_py",
|
||||
srcs = ["python/training/gce_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tfconfig_cluster_resolver_py",
|
||||
srcs = ["python/training/tfconfig_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_cluster_resolver_py",
|
||||
srcs = ["python/training/tpu_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "slurm_cluster_resolver_py",
|
||||
srcs = ["python/training/slurm_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "kubernetes_cluster_resolver_py",
|
||||
srcs = ["python/training/kubernetes_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
deps = ["//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "base_cluster_resolver_py_test",
|
||||
srcs = ["python/training/cluster_resolver_test.py"],
|
||||
name = "cluster_resolver_initialization_test",
|
||||
srcs = ["cluster_resolver_initialization_test.py"],
|
||||
additional_deps = [
|
||||
":cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -108,86 +41,5 @@ tf_py_test(
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
main = "python/training/cluster_resolver_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "gce_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["python/training/gce_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":cluster_resolver_py",
|
||||
":gce_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
main = "python/training/gce_cluster_resolver_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tfconfig_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["python/training/tfconfig_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":tfconfig_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
main = "python/training/tfconfig_cluster_resolver_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["python/training/tpu_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":tpu_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
main = "python/training/tpu_cluster_resolver_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "slurm_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["python/training/slurm_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":cluster_resolver_py",
|
||||
":slurm_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
main = "python/training/slurm_cluster_resolver_test.py",
|
||||
tags = [],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "kubernetes_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["python/training/kubernetes_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":cluster_resolver_py",
|
||||
":kubernetes_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
main = "python/training/kubernetes_cluster_resolver_test.py",
|
||||
main = "cluster_resolver_initialization_test.py",
|
||||
)
|
||||
|
@ -20,14 +20,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.kubernetes_cluster_resolver import KubernetesClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
@ -0,0 +1,53 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests to ensure ClusterResolvers are usable via the old contrib path."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training import cluster_resolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training import UnionClusterResolver
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
|
||||
class ClusterResolverInitializationTest(test.TestCase):
|
||||
|
||||
def testCreateSimpleClusterResolverFromLib(self):
|
||||
base_cluster_spec = server_lib.ClusterSpec({
|
||||
"ps": ["ps0:2222", "ps1:2222"],
|
||||
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
|
||||
})
|
||||
cluster_resolver.SimpleClusterResolver(base_cluster_spec)
|
||||
|
||||
def testCreateSimpleClusterResolver(self):
|
||||
base_cluster_spec = server_lib.ClusterSpec({
|
||||
"ps": ["ps0:2222", "ps1:2222"],
|
||||
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
|
||||
})
|
||||
SimpleClusterResolver(base_cluster_spec)
|
||||
|
||||
def testCreateUnionClusterResolver(self):
|
||||
base_cluster_spec = server_lib.ClusterSpec({
|
||||
"ps": ["ps0:2222", "ps1:2222"],
|
||||
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
|
||||
})
|
||||
simple_cr = SimpleClusterResolver(base_cluster_spec)
|
||||
UnionClusterResolver(simple_cr)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -18,11 +18,36 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.kubernetes_cluster_resolver import KubernetesClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
|
||||
# This file (and all files in this directory in general) is a backwards
|
||||
# compatibility shim that exists to re-export ClusterResolvers such that
|
||||
# existing OSS code will not be broken.
|
||||
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = [
|
||||
'cluster_resolver',
|
||||
'gce_cluster_resolver',
|
||||
'kubernetes_cluster_resolver',
|
||||
'slurm_cluster_resolver',
|
||||
'tfconfig_cluster_resolver',
|
||||
'tpu_cluster_resolver',
|
||||
'ClusterResolver',
|
||||
'SimpleClusterResolver',
|
||||
'UnionClusterResolver',
|
||||
'GceClusterResolver',
|
||||
'KubernetesClusterResolver',
|
||||
'TFConfigClusterResolver',
|
||||
'TPUClusterResolver',
|
||||
'SlurmClusterResolver',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,363 +12,29 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution."""
|
||||
"""Stub file for ClusterResolver to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
# This file (and all files in this directory in general) is a backwards
|
||||
# compatibility shim that exists to re-export ClusterResolvers such that
|
||||
# existing OSS code will not be broken.
|
||||
|
||||
import six
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
|
||||
# pylint: enable=unused-import
|
||||
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = [
|
||||
'ClusterResolver',
|
||||
'SimpleClusterResolver',
|
||||
'UnionClusterResolver',
|
||||
]
|
||||
|
||||
def format_master_url(master, rpc_layer=None):
|
||||
if rpc_layer:
|
||||
return '%s://%s' % (rpc_layer, master)
|
||||
else:
|
||||
return master
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class ClusterResolver(object):
|
||||
"""Abstract class for all implementations of ClusterResolvers.
|
||||
|
||||
This defines the skeleton for all implementations of ClusterResolvers.
|
||||
ClusterResolvers are a way for TensorFlow to communicate with various cluster
|
||||
management systems (e.g. GCE, AWS, etc...).
|
||||
|
||||
By letting TensorFlow communicate with these systems, we will be able to
|
||||
automatically discover and resolve IP addresses for various TensorFlow
|
||||
workers. This will eventually allow us to automatically recover from
|
||||
underlying machine failures and scale TensorFlow worker clusters up and down.
|
||||
|
||||
Note to Implementors: In addition to these abstract methods, you must also
|
||||
implement the task_type, task_index, and rpc_layer attributes. You may choose
|
||||
to implement them either as properties with getters or setters or directly
|
||||
set the attributes.
|
||||
|
||||
- task_type is the name of the server's current named job (e.g. 'worker',
|
||||
'ps' in a distributed parameterized training job).
|
||||
- task_index is the ordinal index of the server within the task type.
|
||||
- rpc_layer is the protocol used by TensorFlow to communicate with other
|
||||
TensorFlow servers in a distributed environment.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def cluster_spec(self):
|
||||
"""Retrieve the current state of the cluster and returns a ClusterSpec.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec representing the state of the cluster at the moment this
|
||||
function is called.
|
||||
|
||||
Implementors of this function must take care in ensuring that the
|
||||
ClusterSpec returned is up-to-date at the time of calling this function.
|
||||
This usually means retrieving the information from the underlying cluster
|
||||
management system every time this function is invoked and reconstructing
|
||||
a cluster_spec, rather than attempting to cache anything.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Retrieves the name or URL of the session master.
|
||||
|
||||
Args:
|
||||
task_type: (Optional) The type of the TensorFlow task of the master.
|
||||
task_index: (Optional) The index of the TensorFlow task of the master.
|
||||
rpc_layer: (Optional) The RPC protocol for the given cluster.
|
||||
|
||||
Returns:
|
||||
The name or URL of the session master.
|
||||
|
||||
Implementors of this function must take care in ensuring that the master
|
||||
returned is up-to-date at the time to calling this function. This usually
|
||||
means retrieving the master every time this function is invoked.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
"""Returns the number of accelerator cores per worker.
|
||||
|
||||
This returns the number of accelerator cores (such as GPUs and TPUs)
|
||||
available per worker. If workers only has CPU cores available, then this
|
||||
should return 0. This method will query the master for this information
|
||||
if it is not otherwise known.
|
||||
|
||||
Args:
|
||||
session_config: (Optional) Configuration for starting a new session to
|
||||
query how many accelerator cores it has.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractproperty
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SimpleClusterResolver(ClusterResolver):
|
||||
"""Simple implementation of ClusterResolver that accepts a ClusterSpec."""
|
||||
|
||||
def __init__(self, cluster_spec, master='', task_type=None, task_index=None,
|
||||
environment='', num_accelerators_per_worker=0,
|
||||
rpc_layer=None):
|
||||
"""Creates a SimpleClusterResolver from a ClusterSpec."""
|
||||
super(SimpleClusterResolver, self).__init__()
|
||||
|
||||
self._task_type = task_type
|
||||
self._task_index = task_index
|
||||
self._environment = environment
|
||||
self._num_accelerators_per_worker = num_accelerators_per_worker
|
||||
self._rpc_layer = rpc_layer
|
||||
|
||||
if not isinstance(cluster_spec, ClusterSpec):
|
||||
raise TypeError('cluster_spec must be a ClusterSpec.')
|
||||
self._cluster_spec = cluster_spec
|
||||
|
||||
if not isinstance(master, str):
|
||||
raise TypeError('master must be a string.')
|
||||
self._master = master
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns the ClusterSpec passed into the constructor."""
|
||||
return self._cluster_spec
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Returns the master address to use when creating a session.
|
||||
|
||||
Args:
|
||||
task_type: (Optional) The type of the TensorFlow task of the master.
|
||||
task_index: (Optional) The index of the TensorFlow task of the master.
|
||||
rpc_layer: (Optional) The RPC used by distributed TensorFlow.
|
||||
|
||||
Returns:
|
||||
The name or URL of the session master.
|
||||
|
||||
If a task_type and task_index is given, this will override the `master`
|
||||
string passed into the initialization function.
|
||||
"""
|
||||
if task_type is not None and task_index is not None:
|
||||
master = self.cluster_spec().task_address(task_type, task_index)
|
||||
else:
|
||||
master = self._master
|
||||
|
||||
return format_master_url(master, rpc_layer=rpc_layer or self._rpc_layer)
|
||||
|
||||
@property
|
||||
def task_type(self):
|
||||
return self._task_type
|
||||
|
||||
@property
|
||||
def task_index(self):
|
||||
return self._task_index
|
||||
|
||||
@task_type.setter
|
||||
def task_type(self, task_type):
|
||||
self._task_type = task_type
|
||||
|
||||
@task_index.setter
|
||||
def task_index(self, task_index):
|
||||
self._task_index = task_index
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
return self._environment
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
"""Returns the number of accelerator cores per worker.
|
||||
|
||||
Args:
|
||||
session_config: Unused. The SimpleClusterResolver does not do automatic
|
||||
detection of accelerators, so a TensorFlow session will never be
|
||||
created, and thus a `session_config` is never necessary here, and will
|
||||
be ignored.
|
||||
"""
|
||||
del session_config
|
||||
return self._num_accelerators_per_worker
|
||||
|
||||
@property
|
||||
def rpc_layer(self):
|
||||
return self._rpc_layer
|
||||
|
||||
@rpc_layer.setter
|
||||
def rpc_layer(self, rpc_layer):
|
||||
self._rpc_layer = rpc_layer
|
||||
|
||||
|
||||
class UnionClusterResolver(ClusterResolver):
|
||||
"""Performs a union on underlying ClusterResolvers.
|
||||
|
||||
This class performs a union given two or more existing ClusterResolvers. It
|
||||
merges the underlying ClusterResolvers, and returns one unified ClusterSpec
|
||||
when cluster_spec is called. The details of the merge function is
|
||||
documented in the cluster_spec function.
|
||||
|
||||
For additional Cluster Resolver properties such as task type, task index,
|
||||
rpc layer, environment, etc..., we will return the value from the first
|
||||
ClusterResolver in the union.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initializes a UnionClusterResolver with other ClusterResolvers.
|
||||
|
||||
Args:
|
||||
*args: `ClusterResolver` objects to be unionized.
|
||||
**kwargs:
|
||||
rpc_layer - (Optional) Override value for the RPC layer used by
|
||||
TensorFlow.
|
||||
task_type - (Optional) Override value for the current task type.
|
||||
task_index - (Optional) Override value for the current task index.
|
||||
|
||||
Raises:
|
||||
TypeError: If any argument is not a subclass of `ClusterResolvers`.
|
||||
ValueError: If there are no arguments passed.
|
||||
"""
|
||||
super(UnionClusterResolver, self).__init__()
|
||||
|
||||
self._rpc_layer = kwargs.pop('rpc_layer', None)
|
||||
self._task_type = kwargs.pop('task_type', None)
|
||||
self._task_index = kwargs.pop('task_index', None)
|
||||
|
||||
if kwargs:
|
||||
raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs))
|
||||
|
||||
if not args:
|
||||
raise ValueError('At least one ClusterResolver is required.')
|
||||
|
||||
for cluster_resolver in args:
|
||||
if not isinstance(cluster_resolver, ClusterResolver):
|
||||
raise TypeError('All arguments must be a sub-class of '
|
||||
'`ClusterResolver.`')
|
||||
self._cluster_resolvers = args
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a union of all the ClusterSpecs from the ClusterResolvers.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec containing host information merged from all the underlying
|
||||
ClusterResolvers.
|
||||
|
||||
Raises:
|
||||
KeyError: If there are conflicting keys detected when merging two or
|
||||
more dictionaries, this exception is raised.
|
||||
|
||||
Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the
|
||||
same job name, we will merge the list/dict of workers.
|
||||
|
||||
If *all* underlying ClusterSpecs expose the set of workers as lists, we will
|
||||
concatenate the lists of workers, starting with the list of workers from
|
||||
the first ClusterResolver passed into the constructor.
|
||||
|
||||
If *any* of the ClusterSpecs expose the set of workers as a dict, we will
|
||||
treat all the sets of workers as dicts (even if they are returned as lists)
|
||||
and will only merge them into a dict if there is no conflicting keys. If
|
||||
there is a conflicting key, we will raise a `KeyError`.
|
||||
"""
|
||||
|
||||
merged_cluster = {}
|
||||
|
||||
# We figure out whether it is all lists for a particular job, or whether
|
||||
# there are dicts inside.
|
||||
for cluster_resolver in self._cluster_resolvers:
|
||||
cluster_spec = cluster_resolver.cluster_spec()
|
||||
cluster_dict = cluster_spec.as_dict()
|
||||
|
||||
for job_name, tasks in cluster_dict.items():
|
||||
if job_name in merged_cluster:
|
||||
# If we see a dict, then we write a dict out regardless.
|
||||
if isinstance(tasks, dict):
|
||||
merged_cluster[job_name] = {}
|
||||
else:
|
||||
# We take whichever type is present.
|
||||
if isinstance(tasks, list):
|
||||
merged_cluster[job_name] = []
|
||||
else:
|
||||
merged_cluster[job_name] = {}
|
||||
|
||||
# We then do the merge as appropriate in merged_cluster[job].
|
||||
for cluster_resolver in self._cluster_resolvers:
|
||||
cluster_spec = cluster_resolver.cluster_spec()
|
||||
cluster_dict = cluster_spec.as_dict()
|
||||
|
||||
for job_name, tasks in cluster_dict.items():
|
||||
if isinstance(merged_cluster[job_name], list):
|
||||
# We all have lists, we can just concatenate and be done.
|
||||
merged_cluster[job_name].extend(tasks)
|
||||
else:
|
||||
if isinstance(tasks, list):
|
||||
# We convert to a dictionary if the type is a list.
|
||||
task_dict = dict(zip(range(0, len(tasks)), tasks))
|
||||
else:
|
||||
# We can simply make a copy (for update) and be done.
|
||||
task_dict = tasks.copy()
|
||||
|
||||
# We detect if there are duplicates, and raise an error if so.
|
||||
task_keys = set(task_dict)
|
||||
merged_keys = set(merged_cluster[job_name].keys())
|
||||
intersected_keys = task_keys.intersection(merged_keys)
|
||||
if intersected_keys:
|
||||
raise KeyError('Duplicate keys detected when merging two '
|
||||
'ClusterSpecs: %s' % repr(intersected_keys))
|
||||
|
||||
# We do the merge after all the processing.
|
||||
merged_cluster[job_name].update(task_dict)
|
||||
|
||||
return ClusterSpec(merged_cluster)
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Returns the master address to use when creating a session.
|
||||
|
||||
This usually returns the master from the first ClusterResolver passed in,
|
||||
but you can override this by specifying the task_type and task_index.
|
||||
|
||||
Args:
|
||||
task_type: (Optional) The type of the TensorFlow task of the master.
|
||||
task_index: (Optional) The index of the TensorFlow task of the master.
|
||||
rpc_layer: (Optional) The RPC protocol for the given cluster.
|
||||
|
||||
Returns:
|
||||
The name or URL of the session master.
|
||||
"""
|
||||
if task_type is not None and task_index is not None:
|
||||
master = self.cluster_spec().task_address(task_type, task_index)
|
||||
return format_master_url(master, rpc_layer or self._rpc_layer)
|
||||
|
||||
return self._cluster_resolvers[0].master(rpc_layer=rpc_layer)
|
||||
|
||||
@property
|
||||
def task_type(self):
|
||||
return self._task_type or self._cluster_resolvers[0].task_type
|
||||
|
||||
@property
|
||||
def task_index(self):
|
||||
return self._task_index or self._cluster_resolvers[0].task_index
|
||||
|
||||
@task_type.setter
|
||||
def task_type(self, task_type):
|
||||
self._task_type = task_type
|
||||
|
||||
@task_index.setter
|
||||
def task_index(self, task_index):
|
||||
self._task_index = task_index
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
return self._cluster_resolvers[0].environment
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
return self._cluster_resolvers[0].num_accelerators_per_worker(
|
||||
session_config)
|
||||
|
||||
@property
|
||||
def rpc_layer(self):
|
||||
return self._rpc_layer or self._cluster_resolvers[0].rpc_layer
|
||||
|
||||
@rpc_layer.setter
|
||||
def rpc_layer(self, rpc_layer):
|
||||
self._rpc_layer = rpc_layer
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,197 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for GCE Instance Groups."""
|
||||
"""Stub file for GceClusterResolver to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# This file (and all files in this directory in general) is a backwards
|
||||
# compatibility shim that exists to re-export ClusterResolvers such that
|
||||
# existing OSS code will not be broken.
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver
|
||||
# pylint: enable=unused-import
|
||||
|
||||
_GOOGLE_API_CLIENT_INSTALLED = True
|
||||
try:
|
||||
from googleapiclient import discovery # pylint: disable=g-import-not-at-top
|
||||
from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
_GOOGLE_API_CLIENT_INSTALLED = False
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = [
|
||||
'GceClusterResolver',
|
||||
]
|
||||
|
||||
def _format_master_url(master, rpc_layer=None):
|
||||
return '%s://%s' % (rpc_layer, master) if rpc_layer else master
|
||||
|
||||
|
||||
class GceClusterResolver(ClusterResolver):
|
||||
"""Cluster Resolver for Google Compute Engine.
|
||||
|
||||
This is an implementation of cluster resolvers for the Google Compute Engine
|
||||
instance group platform. By specifying a project, zone, and instance group,
|
||||
this will retrieve the IP address of all the instances within the instance
|
||||
group and return a Cluster Resolver object suitable for use for distributed
|
||||
TensorFlow.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
project,
|
||||
zone,
|
||||
instance_group,
|
||||
port,
|
||||
task_type='worker',
|
||||
task_index=0,
|
||||
rpc_layer='grpc',
|
||||
num_accelerators_per_worker=0,
|
||||
credentials='default',
|
||||
service=None):
|
||||
"""Creates a new GceClusterResolver object.
|
||||
|
||||
This takes in a few parameters and creates a GceClusterResolver project. It
|
||||
will then use these parameters to query the GCE API for the IP addresses of
|
||||
each instance in the instance group.
|
||||
|
||||
Args:
|
||||
project: Name of the GCE project.
|
||||
zone: Zone of the GCE instance group.
|
||||
instance_group: Name of the GCE instance group.
|
||||
port: Port of the listening TensorFlow server (default: 8470)
|
||||
task_type: Name of the TensorFlow job this GCE instance group of VM
|
||||
instances belong to.
|
||||
task_index: The task index for this particular VM, within the GCE
|
||||
instance group. In particular, every single instance should be assigned
|
||||
a unique ordinal index within an instance group manually so that they
|
||||
can be distinguished from each other.
|
||||
rpc_layer: The RPC layer TensorFlow should use to communicate across
|
||||
instances.
|
||||
num_accelerators_per_worker: Number of accelerators (GPUs) present per
|
||||
instance.
|
||||
credentials: GCE Credentials. If nothing is specified, this defaults to
|
||||
GoogleCredentials.get_application_default().
|
||||
service: The GCE API object returned by the googleapiclient.discovery
|
||||
function. (Default: discovery.build('compute', 'v1')). If you specify a
|
||||
custom service object, then the credentials parameter will be ignored.
|
||||
|
||||
Raises:
|
||||
ImportError: If the googleapiclient is not installed.
|
||||
"""
|
||||
self._project = project
|
||||
self._zone = zone
|
||||
self._instance_group = instance_group
|
||||
self._task_type = task_type
|
||||
self._task_index = task_index
|
||||
self._rpc_layer = rpc_layer
|
||||
self._port = port
|
||||
self._credentials = credentials
|
||||
|
||||
if credentials == 'default':
|
||||
if _GOOGLE_API_CLIENT_INSTALLED:
|
||||
self._credentials = GoogleCredentials.get_application_default()
|
||||
|
||||
if service is None:
|
||||
if not _GOOGLE_API_CLIENT_INSTALLED:
|
||||
raise ImportError('googleapiclient must be installed before using the '
|
||||
'GCE cluster resolver')
|
||||
self._service = discovery.build(
|
||||
'compute', 'v1',
|
||||
credentials=self._credentials)
|
||||
else:
|
||||
self._service = service
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a ClusterSpec object based on the latest instance group info.
|
||||
|
||||
This returns a ClusterSpec object for use based on information from the
|
||||
specified instance group. We will retrieve the information from the GCE APIs
|
||||
every time this method is called.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec containing host information retrieved from GCE.
|
||||
"""
|
||||
request_body = {'instanceState': 'RUNNING'}
|
||||
request = self._service.instanceGroups().listInstances(
|
||||
project=self._project,
|
||||
zone=self._zone,
|
||||
instanceGroups=self._instance_group,
|
||||
body=request_body,
|
||||
orderBy='name')
|
||||
|
||||
worker_list = []
|
||||
|
||||
while request is not None:
|
||||
response = request.execute()
|
||||
|
||||
items = response['items']
|
||||
for instance in items:
|
||||
instance_name = instance['instance'].split('/')[-1]
|
||||
|
||||
instance_request = self._service.instances().get(
|
||||
project=self._project,
|
||||
zone=self._zone,
|
||||
instance=instance_name)
|
||||
|
||||
if instance_request is not None:
|
||||
instance_details = instance_request.execute()
|
||||
ip_address = instance_details['networkInterfaces'][0]['networkIP']
|
||||
instance_url = '%s:%s' % (ip_address, self._port)
|
||||
worker_list.append(instance_url)
|
||||
|
||||
request = self._service.instanceGroups().listInstances_next(
|
||||
previous_request=request,
|
||||
previous_response=response)
|
||||
|
||||
worker_list.sort()
|
||||
return ClusterSpec({self._task_type: worker_list})
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
task_type = task_type if task_type is not None else self._task_type
|
||||
task_index = task_index if task_index is not None else self._task_index
|
||||
|
||||
if task_type is not None and task_index is not None:
|
||||
master = self.cluster_spec().task_address(task_type, task_index)
|
||||
if rpc_layer or self._rpc_layer:
|
||||
return '%s://%s' % (rpc_layer or self._rpc_layer, master)
|
||||
else:
|
||||
return master
|
||||
|
||||
return ''
|
||||
|
||||
@property
|
||||
def task_type(self):
|
||||
return self._task_type
|
||||
|
||||
@property
|
||||
def task_index(self):
|
||||
return self._task_index
|
||||
|
||||
@task_type.setter
|
||||
def task_type(self, task_type):
|
||||
raise RuntimeError(
|
||||
'You cannot reset the task_type of the GceClusterResolver after it has '
|
||||
'been created.')
|
||||
|
||||
@task_index.setter
|
||||
def task_index(self, task_index):
|
||||
self._task_index = task_index
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in.
|
||||
|
||||
For users in the GCE environment, the environment property is always an
|
||||
empty string, and Google users will not use this ClusterResolver for running
|
||||
on internal systems.
|
||||
"""
|
||||
return ''
|
||||
|
||||
@property
|
||||
def rpc_layer(self):
|
||||
return self._rpc_layer
|
||||
|
||||
@rpc_layer.setter
|
||||
def rpc_layer(self, rpc_layer):
|
||||
self._rpc_layer = rpc_layer
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
del session_config # Unused, since this is set manually in __init__.
|
||||
return self._num_accelerators_per_worker
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
|
@ -12,162 +12,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for Kubernetes."""
|
||||
"""Stub file for KubernetesClusterResolver for backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import format_master_url
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.training import server_lib
|
||||
# This file (and all files in this directory in general) is a backwards
|
||||
# compatibility shim that exists to re-export ClusterResolvers such that
|
||||
# existing OSS code will not be broken.
|
||||
|
||||
_KUBERNETES_API_CLIENT_INSTALLED = True
|
||||
try:
|
||||
from kubernetes import client as k8sclient # pylint: disable=g-import-not-at-top
|
||||
from kubernetes import config as k8sconfig # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
_KUBERNETES_API_CLIENT_INSTALLED = False
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver
|
||||
# pylint: enable=unused-import
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
class KubernetesClusterResolver(ClusterResolver):
|
||||
"""Cluster Resolver for Kubernetes.
|
||||
_allowed_symbols = [
|
||||
'KubernetesClusterResolver',
|
||||
]
|
||||
|
||||
This is an implementation of cluster resolvers for Kubernetes. When given the
|
||||
the Kubernetes namespace and label selector for pods, we will retrieve the
|
||||
pod IP addresses of all running pods matching the selector, and return a
|
||||
ClusterSpec based on that information.
|
||||
"""
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
|
||||
def __init__(self,
|
||||
job_to_label_mapping=None,
|
||||
tf_server_port=8470,
|
||||
rpc_layer='grpc',
|
||||
override_client=None):
|
||||
"""Initializes a new KubernetesClusterResolver.
|
||||
|
||||
This initializes a new Kubernetes Cluster Resolver. The Cluster Resolver
|
||||
will attempt to talk to the Kubernetes master to retrieve all the instances
|
||||
of pods matching a label selector.
|
||||
|
||||
Args:
|
||||
job_to_label_mapping: A mapping of TensorFlow jobs to label selectors.
|
||||
This allows users to specify many TensorFlow jobs in one Cluster
|
||||
Resolver, and each job can have pods belong with different label
|
||||
selectors. For example, a sample mapping might be
|
||||
```
|
||||
{'worker': ['job-name=worker-cluster-a', 'job-name=worker-cluster-b'],
|
||||
'ps': ['job-name=ps-1', 'job-name=ps-2']}
|
||||
```
|
||||
tf_server_port: The port the TensorFlow server is listening on.
|
||||
rpc_layer: (Optional) The RPC layer TensorFlow should use to communicate
|
||||
between tasks in Kubernetes. Defaults to 'grpc'.
|
||||
override_client: The Kubernetes client (usually automatically retrieved
|
||||
using `from kubernetes import client as k8sclient`). If you pass this
|
||||
in, you are responsible for setting Kubernetes credentials manually.
|
||||
|
||||
Raises:
|
||||
ImportError: If the Kubernetes Python client is not installed and no
|
||||
`override_client` is passed in.
|
||||
RuntimeError: If autoresolve_task is not a boolean or a callable.
|
||||
"""
|
||||
if _KUBERNETES_API_CLIENT_INSTALLED:
|
||||
k8sconfig.load_kube_config()
|
||||
|
||||
if not job_to_label_mapping:
|
||||
job_to_label_mapping = {'worker': ['job-name=tensorflow']}
|
||||
|
||||
if not override_client and not _KUBERNETES_API_CLIENT_INSTALLED:
|
||||
raise ImportError('The Kubernetes Python client must be installed before'
|
||||
'using the Kubernetes Cluster Resolver. To install the'
|
||||
'Kubernetes Python client, run `pip install '
|
||||
'kubernetes` on your command line.')
|
||||
|
||||
self._job_to_label_mapping = job_to_label_mapping
|
||||
self._tf_server_port = tf_server_port
|
||||
self._override_client = override_client
|
||||
|
||||
self.task_type = None
|
||||
self.task_index = None
|
||||
self.rpc_layer = rpc_layer
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Returns the master address to use when creating a session.
|
||||
|
||||
You must have set the task_type and task_index object properties before
|
||||
calling this function, or pass in the `task_type` and `task_index`
|
||||
parameters when using this function. If you do both, the function parameters
|
||||
will override the object properties.
|
||||
|
||||
Args:
|
||||
task_type: (Optional) The type of the TensorFlow task of the master.
|
||||
task_index: (Optional) The index of the TensorFlow task of the master.
|
||||
rpc_layer: (Optional) The RPC protocol for the given cluster.
|
||||
|
||||
Returns:
|
||||
The name or URL of the session master.
|
||||
"""
|
||||
if task_type is not None and task_index is not None:
|
||||
return format_master_url(
|
||||
self.cluster_spec().task_address(task_type, task_index),
|
||||
rpc_layer or self.rpc_layer)
|
||||
|
||||
if self.task_type is not None and self.task_index is not None:
|
||||
return format_master_url(
|
||||
self.cluster_spec().task_address(self.task_type, self.task_index),
|
||||
rpc_layer or self.rpc_layer)
|
||||
|
||||
return ''
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a ClusterSpec object based on the latest info from Kubernetes.
|
||||
|
||||
We retrieve the information from the Kubernetes master every time this
|
||||
method is called.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec containing host information returned from Kubernetes.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If any of the pods returned by the master is not in the
|
||||
`Running` phase.
|
||||
"""
|
||||
if not self._override_client:
|
||||
k8sconfig.load_kube_config()
|
||||
|
||||
client = self._override_client or k8sclient.CoreV1Api()
|
||||
cluster_map = {}
|
||||
|
||||
for tf_job in self._job_to_label_mapping:
|
||||
all_pods = []
|
||||
for selector in self._job_to_label_mapping[tf_job]:
|
||||
ret = client.list_pod_for_all_namespaces(label_selector=selector)
|
||||
selected_pods = []
|
||||
|
||||
# Sort the list by the name to make sure it doesn't change call to call.
|
||||
for pod in sorted(ret.items, key=lambda x: x.metadata.name):
|
||||
if pod.status.phase == 'Running':
|
||||
selected_pods.append(
|
||||
'%s:%s' % (pod.status.host_ip, self._tf_server_port))
|
||||
else:
|
||||
raise RuntimeError('Pod "%s" is not running; phase: "%s"' %
|
||||
(pod.metadata.name, pod.status.phase))
|
||||
all_pods.extend(selected_pods)
|
||||
cluster_map[tf_job] = all_pods
|
||||
|
||||
return server_lib.ClusterSpec(cluster_map)
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in.
|
||||
|
||||
For users in the Cloud environment, the environment property is always an
|
||||
empty string, and Google users will not use this ClusterResolver for running
|
||||
on internal systems.
|
||||
"""
|
||||
return ''
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
local_devices = device_lib.list_local_devices(session_config)
|
||||
return len([d for d in local_devices if d.device_type == 'GPU'])
|
||||
|
@ -12,215 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for Slurm workload manager."""
|
||||
"""Stub file for SlurmClusterResolver to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import subprocess
|
||||
# This file (and all files in this directory in general) is a backwards
|
||||
# compatibility shim that exists to re-export ClusterResolvers such that
|
||||
# existing OSS code will not be broken.
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver
|
||||
# pylint: enable=unused-import
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
class SlurmClusterResolver(ClusterResolver):
|
||||
"""Cluster Resolver for system with Slurm workload manager.
|
||||
_allowed_symbols = [
|
||||
'SlurmClusterResolver',
|
||||
]
|
||||
|
||||
This is an implementation of cluster resolvers for Slurm clusters. This allows
|
||||
the specification of jobs and task counts, number of tasks per node, number of
|
||||
GPUs on each node and number of GPUs for each task, It retrieves system
|
||||
attributes by Slurm environment variables, resolves allocated computing node
|
||||
names, construct a cluster and return a Cluster Resolver object which an be
|
||||
use for distributed TensorFlow.
|
||||
"""
|
||||
|
||||
def _resolve_hostnames(self):
|
||||
"""Resolve host names of nodes allocated in current jobs.
|
||||
|
||||
Returns:
|
||||
A list of node names as strings.
|
||||
"""
|
||||
hostlist = (subprocess.check_output(['scontrol', 'show', 'hostname']).
|
||||
decode('utf-8').strip().split('\n'))
|
||||
return hostlist
|
||||
|
||||
def __init__(self,
|
||||
jobs,
|
||||
port_base=8888,
|
||||
gpus_per_node=1,
|
||||
gpus_per_task=1,
|
||||
tasks_per_node=None,
|
||||
auto_set_gpu=True,
|
||||
rpc_layer='grpc'):
|
||||
"""Creates a new SlurmClusterResolver object.
|
||||
|
||||
This takes in parameters and creates a SlurmClusterResolver object. It uses
|
||||
those parameters to check which nodes will processes reside and resolves
|
||||
their hostnames. With the number of the GPUs on each node and number of GPUs
|
||||
for each task it offsets the port number for each processes and allocate
|
||||
GPUs to tasks by setting environment variables. The resolver currently
|
||||
supports homogeneous tasks and default Slurm process allocation.
|
||||
|
||||
Args:
|
||||
jobs: Dictionary with job names as key and number of tasks in the job as
|
||||
value
|
||||
port_base: The first port number to start with for processes on a node.
|
||||
gpus_per_node: Number of GPUs available on each node.
|
||||
gpus_per_task: Number of GPUs to be used for each task.
|
||||
tasks_per_node: Number of tasks to run on each node, if not set defaults
|
||||
to Slurm's output environment variable SLURM_NTASKS_PER_NODE.
|
||||
auto_set_gpu: Set the visible CUDA devices automatically while resolving
|
||||
the cluster by setting CUDA_VISIBLE_DEVICES environment variable.
|
||||
Defaults to True.
|
||||
rpc_layer: (Optional) The protocol TensorFlow uses to communicate between
|
||||
nodes. Defaults to 'grpc'.
|
||||
|
||||
Returns:
|
||||
A ClusterResolver object which can be used with distributed TensorFlow.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If requested more GPUs per node then available or requested
|
||||
more tasks then assigned tasks.
|
||||
"""
|
||||
|
||||
# check if launched by mpirun
|
||||
if 'OMPI_COMM_WORLD_RANK' in os.environ:
|
||||
self._rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
||||
num_tasks = int(os.environ['OMPI_COMM_WORLD_SIZE'])
|
||||
else:
|
||||
self._rank = int(os.environ['SLURM_PROCID'])
|
||||
num_tasks = int(os.environ['SLURM_NTASKS'])
|
||||
|
||||
self._jobs = collections.OrderedDict(sorted(jobs.items()))
|
||||
self._port_base = port_base
|
||||
|
||||
# user specification overrides SLURM specification
|
||||
if tasks_per_node is not None:
|
||||
self._tasks_per_node = tasks_per_node
|
||||
elif tasks_per_node is None and 'SLURM_NTASKS_PER_NODE' in os.environ:
|
||||
self._tasks_per_node = int(os.environ['SLURM_NTASKS_PER_NODE'])
|
||||
else:
|
||||
raise RuntimeError('Neither `tasks_per_node` or '
|
||||
'SLURM_NTASKS_PER_NODE is set.')
|
||||
|
||||
self._gpus_per_node = gpus_per_node
|
||||
self._gpus_per_task = gpus_per_task
|
||||
|
||||
self._auto_set_gpu = auto_set_gpu
|
||||
self.task_type = None
|
||||
self.task_index = None
|
||||
self.rpc_layer = rpc_layer
|
||||
|
||||
self._gpu_allocation = []
|
||||
self._cluster_allocation = {}
|
||||
|
||||
if self._tasks_per_node * self._gpus_per_task > self._gpus_per_node:
|
||||
raise RuntimeError('Requested more GPUs per node then available.')
|
||||
|
||||
if sum(self._jobs.values()) != num_tasks:
|
||||
raise RuntimeError('Requested more tasks then assigned tasks.')
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a ClusterSpec object based on the latest instance group info.
|
||||
|
||||
This returns a ClusterSpec object for use based on information from the
|
||||
specified initialization parameters and Slurm environment variables. The
|
||||
cluster specification is resolved each time this function is called. The
|
||||
resolver extract hostnames of nodes by scontrol and pack tasks in that
|
||||
order until a node a has number of tasks that is equal to specification.
|
||||
GPUs on nodes are allocated to tasks by specification through setting
|
||||
CUDA_VISIBLE_DEVICES environment variable.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec containing host information retrieved from Slurm's
|
||||
environment variables.
|
||||
"""
|
||||
hostlist = self._resolve_hostnames()
|
||||
|
||||
task_list = []
|
||||
self._gpu_allocation = []
|
||||
self._cluster_allocation = {}
|
||||
|
||||
for host in hostlist:
|
||||
for port_offset, gpu_offset in zip(
|
||||
range(self._tasks_per_node),
|
||||
range(0, self._gpus_per_node, self._gpus_per_task)):
|
||||
|
||||
host_addr = '%s:%d' % (host, self._port_base + port_offset)
|
||||
task_list.append(host_addr)
|
||||
gpu_id_list = []
|
||||
|
||||
for gpu_id in range(gpu_offset, gpu_offset + self._gpus_per_task):
|
||||
gpu_id_list.append(str(gpu_id))
|
||||
|
||||
self._gpu_allocation.append(','.join(gpu_id_list))
|
||||
|
||||
cluster_rank_offset_start = 0
|
||||
cluster_rank_offset_end = 0
|
||||
|
||||
for task_type, num_tasks in self._jobs.items():
|
||||
cluster_rank_offset_end = cluster_rank_offset_start + num_tasks
|
||||
|
||||
self._cluster_allocation[task_type] = (
|
||||
task_list[cluster_rank_offset_start:cluster_rank_offset_end])
|
||||
|
||||
if cluster_rank_offset_start <= self._rank < cluster_rank_offset_end:
|
||||
self.task_type = task_type
|
||||
self.task_index = self._rank - cluster_rank_offset_start
|
||||
|
||||
cluster_rank_offset_start = cluster_rank_offset_end
|
||||
|
||||
if self._auto_set_gpu is True:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = self._gpu_allocation[self._rank]
|
||||
|
||||
return ClusterSpec(self._cluster_allocation)
|
||||
|
||||
def get_task_info(self):
|
||||
"""Returns job name and task_index for the process which calls this.
|
||||
|
||||
This returns the job name and task index for the process which calls this
|
||||
function according to its rank and cluster specification. The job name and
|
||||
task index are set after a cluster is constructed by cluster_spec otherwise
|
||||
defaults to None.
|
||||
|
||||
Returns:
|
||||
A string specifying job name the process belongs to and an integner
|
||||
specifying the task index the process belongs to in that job.
|
||||
"""
|
||||
return self.task_type, self.task_index
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Returns the master string for connecting to a TensorFlow master.
|
||||
|
||||
Args:
|
||||
task_type: (Optional) Overrides the default auto-selected task type.
|
||||
task_index: (Optional) Overrides the default auto-slected task index.
|
||||
rpc_layer: (Optional) Overrides the default RPC protocol TensorFlow uses
|
||||
to communicate across nodes.
|
||||
|
||||
Returns:
|
||||
A connection string for connecting to a TensorFlow master.
|
||||
"""
|
||||
task_type = task_type if task_type is not None else self.task_type
|
||||
task_index = task_index if task_index is not None else self.task_index
|
||||
rpc_layer = rpc_layer or self.rpc_layer
|
||||
master = self.cluster_spec().task_address(task_type, task_index)
|
||||
|
||||
return '%s://%s' % (rpc_layer, master) if rpc_layer else master
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in.
|
||||
|
||||
For users in the Slurm environment, the environment property is always an
|
||||
empty string, and Google users will not use this ClusterResolver for running
|
||||
on internal systems.
|
||||
"""
|
||||
return ''
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
del session_config # Unused, since this is set in __init__ manually.
|
||||
return self._gpus_per_node
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
|
@ -12,160 +12,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables."""
|
||||
|
||||
"""Stub file for TFConfigClusterResolver to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
# This file (and all files in this directory in general) is a backwards
|
||||
# compatibility shim that exists to re-export ClusterResolvers such that
|
||||
# existing OSS code will not be broken.
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver
|
||||
# pylint: enable=unused-import
|
||||
|
||||
_TF_CONFIG_ENV = 'TF_CONFIG'
|
||||
_SESSION_MASTER_KEY = 'session_master'
|
||||
_RPC_LAYER_KEY = 'rpc_layer'
|
||||
_TASK_KEY = 'task'
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = [
|
||||
'TFConfigClusterResolver',
|
||||
]
|
||||
|
||||
def format_master_url(master, rpc_layer=None):
|
||||
if rpc_layer:
|
||||
return '%s://%s' % (rpc_layer, master)
|
||||
else:
|
||||
return master
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
||||
|
||||
|
||||
def _load_tf_config():
|
||||
return json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
|
||||
|
||||
|
||||
def _get_value_in_tfconfig(key, default=None):
|
||||
tf_config = _load_tf_config()
|
||||
return tf_config[key] if key in tf_config else default
|
||||
|
||||
|
||||
class TFConfigClusterResolver(ClusterResolver):
|
||||
"""Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar."""
|
||||
|
||||
def __init__(self,
|
||||
task_type=None,
|
||||
task_index=None,
|
||||
rpc_layer=None,
|
||||
environment=None,
|
||||
num_accelerators_per_worker=0):
|
||||
"""Creates a new TFConfigClusterResolver.
|
||||
|
||||
Args:
|
||||
task_type: (String, optional) Overrides the task type specified in the
|
||||
TF_CONFIG environment variable.
|
||||
task_index: (Integer, optional) Overrides the task index specified in the
|
||||
TF_CONFIG environment variable.
|
||||
rpc_layer: (String, optional) Overrides the rpc layer TensorFlow uses.
|
||||
environment: (String, optional) Overrides the environment TensorFlow
|
||||
operates in.
|
||||
num_accelerators_per_worker: (Integer, optional) Specifies the number of
|
||||
accelerators (e.g. GPUs, TPUs, others) that each node has.
|
||||
"""
|
||||
|
||||
self._task_type = task_type
|
||||
self._task_index = task_index
|
||||
self._rpc_layer = rpc_layer
|
||||
self._environment = environment
|
||||
self._num_accelerators_per_worker = num_accelerators_per_worker
|
||||
|
||||
@property
|
||||
def task_type(self):
|
||||
if self._task_type is None:
|
||||
task_info = _get_value_in_tfconfig(_TASK_KEY, {})
|
||||
return task_info['type'] if 'type' in task_info else None
|
||||
else:
|
||||
return self._task_type
|
||||
|
||||
@property
|
||||
def task_index(self):
|
||||
if self._task_index is None:
|
||||
task_info = _get_value_in_tfconfig(_TASK_KEY, {})
|
||||
return task_info['index'] if 'index' in task_info else None
|
||||
else:
|
||||
return self._task_index
|
||||
|
||||
@task_type.setter
|
||||
def task_type(self, task_type):
|
||||
self._task_type = task_type
|
||||
|
||||
@task_index.setter
|
||||
def task_index(self, task_index):
|
||||
self._task_index = task_index
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
return self._environment
|
||||
|
||||
@property
|
||||
def rpc_layer(self):
|
||||
if self._rpc_layer is None:
|
||||
return _get_value_in_tfconfig(_RPC_LAYER_KEY)
|
||||
else:
|
||||
return self._rpc_layer
|
||||
|
||||
@rpc_layer.setter
|
||||
def rpc_layer(self, rpc_layer):
|
||||
self._rpc_layer = rpc_layer
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
# TODO(frankchn): Connect to server (w/ session_config) in the future.
|
||||
del session_config # Unused, we do not connect to another server here.
|
||||
return self._num_accelerators_per_worker
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a ClusterSpec based on the TF_CONFIG environment variable.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec with information from the TF_CONFIG environment variable.
|
||||
"""
|
||||
tf_config = _load_tf_config()
|
||||
if 'cluster' not in tf_config:
|
||||
return ClusterSpec({})
|
||||
return ClusterSpec(tf_config['cluster'])
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Returns the master address to use when creating a TensorFlow session.
|
||||
|
||||
Args:
|
||||
task_type: (String, optional) Overrides and sets the task_type of the
|
||||
master.
|
||||
task_index: (Integer, optional) Overrides and sets the task id of the
|
||||
master.
|
||||
rpc_layer: (String, optional) Overrides and sets the protocol over which
|
||||
TensorFlow nodes communicate with each other.
|
||||
|
||||
Returns:
|
||||
The address of the master.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the task_type or task_id is not specified and the
|
||||
`TF_CONFIG` environment variable does not contain a task section.
|
||||
"""
|
||||
|
||||
# If `session_master` is set, just use that.
|
||||
session_master = _get_value_in_tfconfig(_SESSION_MASTER_KEY)
|
||||
if session_master is not None:
|
||||
return session_master
|
||||
|
||||
# Return an empty string if we are the only job in the ClusterSpec.
|
||||
cluster_spec = self.cluster_spec()
|
||||
if (not cluster_spec.jobs or
|
||||
(len(cluster_spec.jobs) == 1 and
|
||||
len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)):
|
||||
return ''
|
||||
|
||||
# We try to auto-detect the task type and id, but use the user-supplied one
|
||||
# where available
|
||||
task_type = task_type if task_type is not None else self.task_type
|
||||
task_index = task_index if task_index is not None else self.task_index
|
||||
|
||||
return format_master_url(cluster_spec.task_address(task_type, task_index),
|
||||
self.rpc_layer)
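As an illustrative sketch (not part of this change; the addresses and cluster layout below are made up), the resolver above maps a TF_CONFIG value to a ClusterSpec and master string like so:

import json
import os

from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver

# Hypothetical TF_CONFIG for a two-worker, one-ps cluster.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ['10.0.0.1:2222', '10.0.0.2:2222'],
        'ps': ['10.0.0.3:2222'],
    },
    'task': {'type': 'worker', 'index': 1},
    'rpc_layer': 'grpc',
})

resolver = TFConfigClusterResolver()
resolver.cluster_spec()  # ClusterSpec built from the 'cluster' section above.
resolver.master()        # 'grpc://10.0.0.2:2222' (worker task 1, grpc rpc layer).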
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,412 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for Cloud TPUs."""
|
||||
"""Stub file for TPUClusterResolver to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
# This file (and all files in this directory in general) is a backwards
|
||||
# compatibility shim that exists to re-export ClusterResolvers such that
|
||||
# existing OSS code will not be broken.
|
||||
|
||||
from six.moves.urllib.request import Request
|
||||
from six.moves.urllib.request import urlopen
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
|
||||
# pylint: enable=unused-import
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import format_master_url
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_GOOGLE_API_CLIENT_INSTALLED = True
|
||||
try:
|
||||
from googleapiclient import discovery # pylint: disable=g-import-not-at-top
|
||||
from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
_GOOGLE_API_CLIENT_INSTALLED = False
|
||||
_allowed_symbols = [
|
||||
'TPUClusterResolver',
|
||||
]
|
||||
|
||||
|
||||
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
|
||||
_ENDPOINTS_SEPARATOR = ','
|
||||
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
|
||||
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
|
||||
|
||||
|
||||
class TPUClusterResolver(ClusterResolver):
|
||||
"""Cluster Resolver for Google Cloud TPUs.
|
||||
|
||||
This is an implementation of cluster resolvers for the Google Cloud TPU
|
||||
service. As Cloud TPUs are in alpha, you will need to specify an API definition
|
||||
file for this to consume, in addition to a list of Cloud TPUs in your Google
|
||||
Cloud Platform project.
|
||||
"""
|
||||
|
||||
def _tpuService(self):
|
||||
"""Creates a new Cloud TPU API object.
|
||||
|
||||
This works around an issue where the underlying HTTP connection sometimes
|
||||
times out when the script has been running for too long. Other methods in
|
||||
this object call this method to get a new API object whenever they need
|
||||
to communicate with the Cloud API.
|
||||
|
||||
Returns:
|
||||
A Google Cloud TPU API object.
|
||||
"""
|
||||
if self._service:
|
||||
return self._service
|
||||
|
||||
credentials = self._credentials
|
||||
if credentials is None or credentials == 'default':
|
||||
credentials = GoogleCredentials.get_application_default()
|
||||
|
||||
if self._discovery_url:
|
||||
return discovery.build(
|
||||
'tpu', 'v1alpha1',
|
||||
credentials=credentials,
|
||||
discoveryServiceUrl=self._discovery_url)
|
||||
else:
|
||||
return discovery.build(
|
||||
'tpu', 'v1alpha1',
|
||||
credentials=credentials)
|
||||
|
||||
def _requestComputeMetadata(self, path):
|
||||
req = Request('http://metadata/computeMetadata/v1/%s' % path,
|
||||
headers={'Metadata-Flavor': 'Google'})
|
||||
resp = urlopen(req)
|
||||
return compat.as_bytes(resp.read())
|
||||
|
||||
def _shouldResolve(self):
|
||||
if isinstance(self._should_resolve_override, bool):
|
||||
return self._should_resolve_override
|
||||
if (self._tpu == compat.as_bytes('') or
|
||||
self._tpu == compat.as_bytes('local') or
|
||||
self._tpu.startswith(compat.as_bytes('/bns')) or
|
||||
self._tpu.startswith(compat.as_bytes('localhost:')) or
|
||||
self._tpu.startswith(compat.as_bytes('grpc://'))):
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _inGke():
|
||||
"""When running in GKE, the environment variable will be set."""
|
||||
return _GKE_ENV_VARIABLE in os.environ
|
||||
|
||||
@staticmethod
|
||||
def _gkeEndpoints():
|
||||
return os.environ[_GKE_ENV_VARIABLE]
|
||||
|
||||
@staticmethod
|
||||
def _envVarFallback():
|
||||
if _DEFAULT_ENV_VARIABLE in os.environ:
|
||||
return os.environ[_DEFAULT_ENV_VARIABLE]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _environmentDiscoveryUrl():
|
||||
return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
|
||||
|
||||
def __init__(self,
|
||||
tpu=None,
|
||||
zone=None,
|
||||
project=None,
|
||||
job_name='worker',
|
||||
coordinator_name=None,
|
||||
coordinator_address=None,
|
||||
credentials='default',
|
||||
service=None,
|
||||
discovery_url=None):
|
||||
"""Creates a new TPUClusterResolver object.
|
||||
|
||||
The ClusterResolver will then use the parameters to query the Cloud TPU APIs
|
||||
for the IP addresses and ports of each Cloud TPU listed.
|
||||
|
||||
Args:
|
||||
tpu: Either a string, or a list of strings corresponding to the TPUs to
|
||||
use. If the single string is the empty string, the string 'local', or a
|
||||
string that begins with 'grpc://' or '/bns', then it is assumed to not
|
||||
correspond with a Cloud TPU and will instead be passed as the session
|
||||
master and no ClusterSpec propagation will be done.
|
||||
zone: Zone where the TPUs are located. If omitted or empty, we will assume
|
||||
that the zone of the TPU is the same as the zone of the GCE VM, which we
|
||||
will try to discover from the GCE metadata service.
|
||||
project: Name of the GCP project containing Cloud TPUs. If omitted or
|
||||
empty, we will try to discover the project name of the GCE VM from the
|
||||
GCE metadata service.
|
||||
job_name: Name of the TensorFlow job the TPUs belong to.
|
||||
coordinator_name: The name to use for the coordinator. Set to None if the
|
||||
coordinator should not be included in the computed ClusterSpec.
|
||||
coordinator_address: The address of the coordinator (typically an ip:port
|
||||
pair). If set to None, a TF server will be started. If coordinator_name
|
||||
is None, a TF server will not be started even if coordinator_address is
|
||||
None.
|
||||
credentials: GCE Credentials. If None, then we use default credentials
|
||||
from the oauth2client library.
|
||||
service: The GCE API object returned by the googleapiclient.discovery
|
||||
function. If you specify a custom service object, then the credentials
|
||||
parameter will be ignored.
|
||||
discovery_url: A URL template that points to the location of
|
||||
the discovery service. It should have two parameters {api} and
|
||||
{apiVersion} that when filled in produce an absolute URL to the
|
||||
discovery document for that service. The environment variable
|
||||
'TPU_API_DISCOVERY_URL' will override this.
|
||||
|
||||
Raises:
|
||||
ImportError: If the googleapiclient is not installed.
|
||||
ValueError: If no TPUs are specified.
|
||||
"""
|
||||
if isinstance(tpu, list):
|
||||
if not tpu:
|
||||
raise ValueError('At least one TPU must be specified.')
|
||||
if len(tpu) != 1:
|
||||
raise NotImplementedError(
|
||||
'Using multiple TPUs in a single session is not yet implemented')
|
||||
tpu = tpu[0]
|
||||
|
||||
in_gke = self._inGke()
|
||||
# When using GKE with Cloud TPUs, the env variable will be set.
|
||||
if tpu is None:
|
||||
if in_gke:
|
||||
tpu = self._gkeEndpoints()
|
||||
else:
|
||||
tpu = self._envVarFallback()
|
||||
|
||||
if tpu is None:
|
||||
raise ValueError('Please provide a TPU Name to connect to.')
|
||||
|
||||
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
|
||||
|
||||
# By default the task_type is 'worker' and the task_index is 0 (which is the
|
||||
# first worker in the task).
|
||||
self.task_type = job_name
|
||||
self.task_index = 0
|
||||
|
||||
if tpu.startswith('grpc://'):
|
||||
# Cloud environment, where we are using GRPC to communicate to TPUs.
|
||||
self._environment = ''
|
||||
elif tpu == 'local' or not tpu:
|
||||
# Google environment, where the TPU is attached to the host.
|
||||
self._environment = 'google'
|
||||
elif tpu.startswith('/bns'):
|
||||
# Google environment, where we reach the TPU through BNS.
|
||||
self._environment = 'google'
|
||||
|
||||
# If TPU is in the Google environment or exists locally, we don't use any
|
||||
# RPC layer.
|
||||
if tpu.startswith('/bns') or tpu == 'local' or not tpu:
|
||||
self.rpc_layer = None
|
||||
else:
|
||||
self.rpc_layer = 'grpc'
|
||||
|
||||
# Setting this overrides the return value of self._shouldResolve()
|
||||
self._should_resolve_override = None
|
||||
|
||||
# We strip out the protocol if it is included, and override the
|
||||
# shouldResolve function to never resolve. We are adding the protocol back
|
||||
# in later in self.master().
|
||||
if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'):
|
||||
tpu = tpu[len(self.rpc_layer + '://'):]
|
||||
self._tpu = tpu
|
||||
self._should_resolve_override = False
|
||||
|
||||
# Whether we should actually attempt to contact Cloud APIs
|
||||
should_resolve = self._shouldResolve()
|
||||
|
||||
# We error out if we are in a non-Cloud environment which cannot talk to the
|
||||
# Cloud APIs using the standard class and a special object is not passed in.
|
||||
self._service = service
|
||||
if (self._service is None and should_resolve and
|
||||
not _GOOGLE_API_CLIENT_INSTALLED):
|
||||
raise ImportError('googleapiclient and oauth2client must be installed '
|
||||
'before using the TPU cluster resolver. Execute: '
|
||||
'`pip install --upgrade google-api-python-client` '
|
||||
'and `pip install --upgrade oauth2client` to '
|
||||
'install with pip.')
|
||||
|
||||
# We save user-passed credentials, unless the user didn't pass in anything.
|
||||
self._credentials = credentials
|
||||
if (credentials == 'default' and should_resolve and
|
||||
_GOOGLE_API_CLIENT_INSTALLED):
|
||||
self._credentials = None
|
||||
|
||||
# Automatically detect project and zone if unspecified.
|
||||
if not project and should_resolve:
|
||||
project = compat.as_str(
|
||||
self._requestComputeMetadata('project/project-id'))
|
||||
if not zone and should_resolve:
|
||||
zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
|
||||
zone = zone_path.split('/')[-1]
|
||||
self._project = project
|
||||
self._zone = zone
|
||||
|
||||
self._discovery_url = self._environmentDiscoveryUrl() or discovery_url
|
||||
|
||||
self._coordinator_name = coordinator_name
|
||||
if (coordinator_name and not coordinator_address and
|
||||
(should_resolve or in_gke)):
|
||||
self._start_local_server()
|
||||
else:
|
||||
self._coordinator_address = coordinator_address
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Get the Master string to be used for the session.
|
||||
|
||||
In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
|
||||
the first instance in the ClusterSpec returned by the cluster_spec function.
|
||||
|
||||
If a non-TPU name is used when constructing a TPUClusterResolver, that will
|
||||
be returned instead (e.g. if the tpu argument's value when constructing
|
||||
this TPUClusterResolver was 'grpc://10.240.1.2:8470',
|
||||
'grpc://10.240.1.2:8470' will be returned).
|
||||
|
||||
Args:
|
||||
task_type: (Optional, string) The type of the TensorFlow task of the
|
||||
master.
|
||||
task_index: (Optional, integer) The index of the TensorFlow task of the
|
||||
master.
|
||||
rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
|
||||
communicate with TPUs.
|
||||
|
||||
Returns:
|
||||
string, the connection string to use when creating a session.
|
||||
|
||||
Raises:
|
||||
ValueError: If none of the TPUs specified exists.
|
||||
"""
|
||||
if self._shouldResolve():
|
||||
# We are going to communicate with the Cloud TPU APIs to get a Cluster.
|
||||
cluster_spec = self.cluster_spec()
|
||||
if task_type is not None and task_index is not None:
|
||||
# task_type and task_index are from the function parameters
|
||||
master = cluster_spec.task_address(task_type, task_index)
|
||||
elif self.task_type is not None and self.task_index is not None:
|
||||
# task_type and task_index are from the object
|
||||
master = cluster_spec.task_address(self.task_type, self.task_index)
|
||||
else:
|
||||
# by default we take the first item in the cluster with the right name
|
||||
job_tasks = cluster_spec.job_tasks(self.task_type)
|
||||
if not job_tasks:
|
||||
raise ValueError('No TPUs with the specified names exist.')
|
||||
master = job_tasks[0]
|
||||
else:
|
||||
if isinstance(self._tpu, (bytes, bytearray)):
|
||||
master = self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0]
|
||||
else:
|
||||
master = self._tpu.split(_ENDPOINTS_SEPARATOR)[0]
|
||||
return format_master_url(master, rpc_layer or self.rpc_layer)
|
||||
|
||||
def get_master(self):
|
||||
return self.master()
|
||||
|
||||
def get_job_name(self):
|
||||
if self._shouldResolve():
|
||||
return self.task_type
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a ClusterSpec object based on the latest TPU information.
|
||||
|
||||
We retrieve the information from the GCE APIs every time this method is
|
||||
called.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec containing host information returned from Cloud TPUs.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the provided TPU is not healthy.
|
||||
"""
|
||||
############################################################################
|
||||
# There are 5 potential cases this code must handle:
|
||||
# 1. [Normal case.] We should resolve the TPU name to a set of tasks, and
|
||||
# a. Create a ClusterSpec that includes the coordinator job
|
||||
# b. Create a ClusterSpec without the coordinator job.
|
||||
# 2. [GKE / No API Access.] We should not resolve the TPU name to a set of
|
||||
# tasks and
|
||||
# a. Create a ClusterSpec with the coordinator
|
||||
# b. Create a ClusterSpec without the coordinator
|
||||
# 3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
|
||||
############################################################################
|
||||
|
||||
if self._shouldResolve():
|
||||
# Case 1.
|
||||
full_name = 'projects/%s/locations/%s/nodes/%s' % (
|
||||
self._project, self._zone, compat.as_text(self._tpu))
|
||||
service = self._tpuService()
|
||||
request = service.projects().locations().nodes().get(name=full_name)
|
||||
response = request.execute()
|
||||
|
||||
if 'state' in response and response['state'] != 'READY':
|
||||
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
|
||||
(compat.as_text(self._tpu), response['state']))
|
||||
|
||||
if 'health' in response and response['health'] != 'HEALTHY':
|
||||
raise RuntimeError('TPU "%s" is unhealthy: "%s"' %
|
||||
(compat.as_text(self._tpu), response['health']))
|
||||
|
||||
if 'networkEndpoints' in response:
|
||||
worker_list = [
|
||||
'%s:%s' % (endpoint['ipAddress'], endpoint['port'])
|
||||
for endpoint in response['networkEndpoints']
|
||||
]
|
||||
else:
|
||||
# Fall back to the deprecated response format
|
||||
instance_url = '%s:%s' % (response['ipAddress'], response['port'])
|
||||
worker_list = [instance_url]
|
||||
|
||||
cluster_spec = {self.task_type: worker_list}
|
||||
else:
|
||||
if self.rpc_layer is None:
|
||||
# Case 3.
|
||||
return None
|
||||
# Case 2.
|
||||
tpus = []
|
||||
for tpu in self._tpu.split(_ENDPOINTS_SEPARATOR):
|
||||
# We are working around the fact that the GKE environment variable that is
|
||||
# supplied to us has the protocol string embedded in it, but we want
|
||||
# to strip it out for the ClusterSpec.
|
||||
if (self.rpc_layer is not None and
|
||||
tpu.startswith(self.rpc_layer + '://')):
|
||||
tpus.append(tpu[len(self.rpc_layer + '://'):])
|
||||
else:
|
||||
tpus.append(tpu)
|
||||
cluster_spec = {self.task_type: tpus}
|
||||
|
||||
if self._coordinator_address:
|
||||
# {1, 2}.a
|
||||
cluster_spec[self._coordinator_name] = [self._coordinator_address]
|
||||
|
||||
return server_lib.ClusterSpec(cluster_spec)
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
"""Returns the number of TPU cores per worker.
|
||||
|
||||
This defaults to 8 for all current TPU configurations, and we do not need
|
||||
to query any remote systems for this.
|
||||
|
||||
Args:
|
||||
session_config: Unused. Not currently necessary to query anything as this
|
||||
number is 8 for all TPU configurations.
|
||||
"""
|
||||
del session_config # Unused. Not necessary to query anything.
|
||||
return 8
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in."""
|
||||
return self._environment
|
||||
|
||||
def _start_local_server(self):
|
||||
address = self._requestComputeMetadata('instance/network-interfaces/0/ip')
|
||||
self._server = server_lib.Server(
|
||||
{
|
||||
'local': ['0.0.0.0:0']
|
||||
}, protocol='grpc', config=None, start=True)
|
||||
# self._server.target is of the form: grpc://ipaddress:port
|
||||
target = compat.as_bytes(self._server.target)
|
||||
splits = target.split(compat.as_bytes(':'))
|
||||
assert len(splits) == 3, self._server.target
|
||||
assert splits[0] == compat.as_bytes('grpc'), self._server.target
|
||||
self._coordinator_port = compat.as_text(splits[2])
|
||||
self._coordinator_address = '%s:%s' % (
|
||||
address, compat.as_text(self._coordinator_port))
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
# TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
|
||||
return self
|
||||
remove_undocumented(__name__, _allowed_symbols)
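A minimal usage sketch (illustrative only; the TPU name, zone, and project are placeholders, and Cloud TPU API access plus the googleapiclient/oauth2client packages are assumed):

from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver

# Resolves the named Cloud TPU through the Cloud TPU API.
resolver = TPUClusterResolver(tpu='my-tpu', zone='us-central1-b', project='my-project')
resolver.cluster_spec()  # ClusterSpec({'worker': ['<tpu-ip>:8470', ...]})
resolver.master()        # e.g. 'grpc://<tpu-ip>:8470'

# Passing a grpc:// address skips API resolution and is used verbatim.
local = TPUClusterResolver(tpu='grpc://10.240.1.2:8470')
local.master()           # 'grpc://10.240.1.2:8470'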
|
||||
|
@ -216,7 +216,7 @@ py_library(
|
||||
],
|
||||
deps = [
|
||||
":tpu_lib",
|
||||
"//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
|
||||
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
|
||||
"//tensorflow/contrib/distribute",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
|
||||
@ -264,7 +264,7 @@ py_library(
|
||||
":tpu_py",
|
||||
"//tensorflow/compiler/xla/experimental/xla_sharding",
|
||||
"//tensorflow/compiler/xla/python_api:xla_shape",
|
||||
"//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
|
||||
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
|
||||
"//tensorflow/contrib/compiler:xla",
|
||||
"//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
|
||||
"//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py",
|
||||
|
@ -3534,6 +3534,19 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Dependency added and used by ClusterResolvers to avoid circular dependency between keras, distribute, and training.
|
||||
py_library(
|
||||
name = "training_server_lib",
|
||||
srcs = ["training/server_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":framework",
|
||||
":pywrap_tensorflow",
|
||||
":util",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "saveable_object",
|
||||
srcs = ["training/saveable_object.py"],
|
||||
|
@ -124,6 +124,7 @@ py_library(
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/data",
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/ops/losses",
|
||||
"//tensorflow/tools/docs:doc_controls",
|
||||
],
|
||||
|
tensorflow/python/distribute/cluster_resolver/BUILD (new file, 180 lines)
@ -0,0 +1,180 @@
|
||||
# Description: Operations defined for Cluster Resolvers
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
py_library(
|
||||
name = "cluster_resolver_lib",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
":gce_cluster_resolver_py",
|
||||
":kubernetes_cluster_resolver_py",
|
||||
":slurm_cluster_resolver_py",
|
||||
":tfconfig_cluster_resolver_py",
|
||||
":tpu_cluster_resolver_py",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "base_cluster_resolver_py",
|
||||
srcs = ["cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gce_cluster_resolver_py",
|
||||
srcs = ["gce_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tfconfig_cluster_resolver_py",
|
||||
srcs = ["tfconfig_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_cluster_resolver_py",
|
||||
srcs = ["tpu_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "slurm_cluster_resolver_py",
|
||||
srcs = ["slurm_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "kubernetes_cluster_resolver_py",
|
||||
srcs = ["kubernetes_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "base_cluster_resolver_py_test",
|
||||
srcs = ["cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
main = "cluster_resolver_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "gce_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["gce_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":gce_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
main = "gce_cluster_resolver_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tfconfig_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["tfconfig_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":tfconfig_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
main = "tfconfig_cluster_resolver_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["tpu_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":tpu_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
main = "tpu_cluster_resolver_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "slurm_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["slurm_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":slurm_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
main = "slurm_cluster_resolver_test.py",
|
||||
tags = [],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "kubernetes_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["kubernetes_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":kubernetes_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
main = "kubernetes_cluster_resolver_test.py",
|
||||
)
|
tensorflow/python/distribute/cluster_resolver/__init__.py (new file, 57 lines)
@ -0,0 +1,57 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Library Imports for Cluster Resolvers."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
||||
from tensorflow.python.distribute.cluster_resolver import gce_cluster_resolver
|
||||
from tensorflow.python.distribute.cluster_resolver import kubernetes_cluster_resolver
|
||||
from tensorflow.python.distribute.cluster_resolver import slurm_cluster_resolver
|
||||
from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
|
||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
|
||||
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = [
|
||||
'cluster_resolver',
|
||||
'gce_cluster_resolver',
|
||||
'kubernetes_cluster_resolver',
|
||||
'slurm_cluster_resolver',
|
||||
'tfconfig_cluster_resolver',
|
||||
'tpu_cluster_resolver',
|
||||
'ClusterResolver',
|
||||
'SimpleClusterResolver',
|
||||
'UnionClusterResolver',
|
||||
'GceClusterResolver',
|
||||
'KubernetesClusterResolver',
|
||||
'TFConfigClusterResolver',
|
||||
'TPUClusterResolver',
|
||||
'SlurmClusterResolver',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
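A sketch of the resulting import paths (the symbols listed above are re-exported; the contrib path remains available as a backwards-compatibility shim):

# New canonical location introduced by this change:
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver

# Existing contrib location, now a thin re-export of the same class:
from tensorflow.contrib.cluster_resolver import TFConfigClusterResolver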
|
||||
|
@ -0,0 +1,374 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
|
||||
|
||||
def format_master_url(master, rpc_layer=None):
|
||||
if rpc_layer:
|
||||
return '%s://%s' % (rpc_layer, master)
|
||||
else:
|
||||
return master
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class ClusterResolver(object):
|
||||
"""Abstract class for all implementations of ClusterResolvers.
|
||||
|
||||
This defines the skeleton for all implementations of ClusterResolvers.
|
||||
ClusterResolvers are a way for TensorFlow to communicate with various cluster
|
||||
management systems (e.g. GCE, AWS, etc...).
|
||||
|
||||
By letting TensorFlow communicate with these systems, we will be able to
|
||||
automatically discover and resolve IP addresses for various TensorFlow
|
||||
workers. This will eventually allow us to automatically recover from
|
||||
underlying machine failures and scale TensorFlow worker clusters up and down.
|
||||
|
||||
Note to Implementors: In addition to these abstract methods, you must also
|
||||
implement the task_type, task_index, and rpc_layer attributes. You may choose
|
||||
to implement them either as properties with getters or setters or directly
|
||||
set the attributes.
|
||||
|
||||
- task_type is the name of the server's current named job (e.g. 'worker',
|
||||
'ps' in a distributed parameterized training job).
|
||||
- task_index is the ordinal index of the server within the task type.
|
||||
- rpc_layer is the protocol used by TensorFlow to communicate with other
|
||||
TensorFlow servers in a distributed environment.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def cluster_spec(self):
|
||||
"""Retrieve the current state of the cluster and returns a ClusterSpec.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec representing the state of the cluster at the moment this
|
||||
function is called.
|
||||
|
||||
Implementors of this function must take care in ensuring that the
|
||||
ClusterSpec returned is up-to-date at the time of calling this function.
|
||||
This usually means retrieving the information from the underlying cluster
|
||||
management system every time this function is invoked and reconstructing
|
||||
a cluster_spec, rather than attempting to cache anything.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Retrieves the name or URL of the session master.
|
||||
|
||||
Args:
|
||||
task_type: (Optional) The type of the TensorFlow task of the master.
|
||||
task_index: (Optional) The index of the TensorFlow task of the master.
|
||||
rpc_layer: (Optional) The RPC protocol for the given cluster.
|
||||
|
||||
Returns:
|
||||
The name or URL of the session master.
|
||||
|
||||
Implementors of this function must take care in ensuring that the master
|
||||
returned is up-to-date at the time of calling this function. This usually
|
||||
means retrieving the master every time this function is invoked.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
"""Returns the number of accelerator cores per worker.
|
||||
|
||||
This returns the number of accelerator cores (such as GPUs and TPUs)
|
||||
available per worker. If workers only have CPU cores available, then this
|
||||
should return 0. This method will query the master for this information
|
||||
if it is not otherwise known.
|
||||
|
||||
Args:
|
||||
session_config: (Optional) Configuration for starting a new session to
|
||||
query how many accelerator cores it has.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractproperty
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in."""
|
||||
raise NotImplementedError()
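As a rough sketch of the contract described above (illustrative only, not part of this change), a concrete resolver supplies the abstract methods plus the task_type, task_index, and rpc_layer attributes:

class _FixedClusterResolver(ClusterResolver):
  # Hypothetical example: resolves to a fixed, in-memory cluster.

  def __init__(self, cluster_dict, rpc_layer='grpc'):
    self.task_type = 'worker'
    self.task_index = 0
    self.rpc_layer = rpc_layer
    self._cluster_dict = cluster_dict

  def cluster_spec(self):
    # Rebuilt on every call, per the cluster_spec contract above.
    return ClusterSpec(self._cluster_dict)

  def master(self, task_type=None, task_index=None, rpc_layer=None):
    task_type = task_type if task_type is not None else self.task_type
    task_index = task_index if task_index is not None else self.task_index
    master = self.cluster_spec().task_address(task_type, task_index)
    return format_master_url(master, rpc_layer or self.rpc_layer)

  def num_accelerators_per_worker(self, session_config=None):
    del session_config  # A fixed cluster has nothing to query.
    return 0

  @property
  def environment(self):
    return ''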
|
||||
|
||||
|
||||
class SimpleClusterResolver(ClusterResolver):
|
||||
"""Simple implementation of ClusterResolver that accepts a ClusterSpec."""
|
||||
|
||||
def __init__(self, cluster_spec, master='', task_type=None, task_index=None,
|
||||
environment='', num_accelerators_per_worker=0,
|
||||
rpc_layer=None):
|
||||
"""Creates a SimpleClusterResolver from a ClusterSpec."""
|
||||
super(SimpleClusterResolver, self).__init__()
|
||||
|
||||
self._task_type = task_type
|
||||
self._task_index = task_index
|
||||
self._environment = environment
|
||||
self._num_accelerators_per_worker = num_accelerators_per_worker
|
||||
self._rpc_layer = rpc_layer
|
||||
|
||||
if not isinstance(cluster_spec, ClusterSpec):
|
||||
raise TypeError('cluster_spec must be a ClusterSpec.')
|
||||
self._cluster_spec = cluster_spec
|
||||
|
||||
if not isinstance(master, str):
|
||||
raise TypeError('master must be a string.')
|
||||
self._master = master
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns the ClusterSpec passed into the constructor."""
|
||||
return self._cluster_spec
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Returns the master address to use when creating a session.
|
||||
|
||||
Args:
|
||||
task_type: (Optional) The type of the TensorFlow task of the master.
|
||||
task_index: (Optional) The index of the TensorFlow task of the master.
|
||||
rpc_layer: (Optional) The RPC used by distributed TensorFlow.
|
||||
|
||||
Returns:
|
||||
The name or URL of the session master.
|
||||
|
||||
If a task_type and task_index are given, this will override the `master`
|
||||
string passed into the initialization function.
|
||||
"""
|
||||
if task_type is not None and task_index is not None:
|
||||
master = self.cluster_spec().task_address(task_type, task_index)
|
||||
else:
|
||||
master = self._master
|
||||
|
||||
return format_master_url(master, rpc_layer=rpc_layer or self._rpc_layer)
|
||||
|
||||
@property
|
||||
def task_type(self):
|
||||
return self._task_type
|
||||
|
||||
@property
|
||||
def task_index(self):
|
||||
return self._task_index
|
||||
|
||||
@task_type.setter
|
||||
def task_type(self, task_type):
|
||||
self._task_type = task_type
|
||||
|
||||
@task_index.setter
|
||||
def task_index(self, task_index):
|
||||
self._task_index = task_index
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
return self._environment
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
"""Returns the number of accelerator cores per worker.
|
||||
|
||||
Args:
|
||||
session_config: Unused. The SimpleClusterResolver does not do automatic
|
||||
detection of accelerators, so a TensorFlow session will never be
|
||||
created, and thus a `session_config` is never necessary here, and will
|
||||
be ignored.
|
||||
"""
|
||||
del session_config
|
||||
return self._num_accelerators_per_worker
|
||||
|
||||
@property
|
||||
def rpc_layer(self):
|
||||
return self._rpc_layer
|
||||
|
||||
@rpc_layer.setter
|
||||
def rpc_layer(self, rpc_layer):
|
||||
self._rpc_layer = rpc_layer
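A brief usage sketch (illustrative only; addresses are placeholders):

spec = ClusterSpec({'worker': ['10.0.0.1:2222', '10.0.0.2:2222']})
resolver = SimpleClusterResolver(spec, master='10.0.0.1:2222', rpc_layer='grpc')
resolver.master()                                   # 'grpc://10.0.0.1:2222'
resolver.master(task_type='worker', task_index=1)   # 'grpc://10.0.0.2:2222'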
|
||||
|
||||
|
||||
class UnionClusterResolver(ClusterResolver):
|
||||
"""Performs a union on underlying ClusterResolvers.
|
||||
|
||||
This class performs a union given two or more existing ClusterResolvers. It
|
||||
merges the underlying ClusterResolvers, and returns one unified ClusterSpec
|
||||
when cluster_spec is called. The details of the merge function are
|
||||
documented in the cluster_spec function.
|
||||
|
||||
For additional Cluster Resolver properties such as task type, task index,
|
||||
rpc layer, environment, etc., we will return the value from the first
|
||||
ClusterResolver in the union.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initializes a UnionClusterResolver with other ClusterResolvers.
|
||||
|
||||
Args:
|
||||
*args: `ClusterResolver` objects to be unionized.
|
||||
**kwargs:
|
||||
rpc_layer - (Optional) Override value for the RPC layer used by
|
||||
TensorFlow.
|
||||
task_type - (Optional) Override value for the current task type.
|
||||
task_index - (Optional) Override value for the current task index.
|
||||
|
||||
Raises:
|
||||
TypeError: If any argument is not a subclass of `ClusterResolver`.
|
||||
ValueError: If there are no arguments passed.
|
||||
"""
|
||||
super(UnionClusterResolver, self).__init__()
|
||||
|
||||
self._rpc_layer = kwargs.pop('rpc_layer', None)
|
||||
self._task_type = kwargs.pop('task_type', None)
|
||||
self._task_index = kwargs.pop('task_index', None)
|
||||
|
||||
if kwargs:
|
||||
raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs))
|
||||
|
||||
if not args:
|
||||
raise ValueError('At least one ClusterResolver is required.')
|
||||
|
||||
for cluster_resolver in args:
|
||||
if not isinstance(cluster_resolver, ClusterResolver):
|
||||
raise TypeError('All arguments must be a sub-class of '
|
||||
'`ClusterResolver`.')
|
||||
self._cluster_resolvers = args
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a union of all the ClusterSpecs from the ClusterResolvers.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec containing host information merged from all the underlying
|
||||
ClusterResolvers.
|
||||
|
||||
Raises:
|
||||
KeyError: If there are conflicting keys detected when merging two or
|
||||
more dictionaries, this exception is raised.
|
||||
|
||||
Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the
|
||||
same job name, we will merge the list/dict of workers.
|
||||
|
||||
If *all* underlying ClusterSpecs expose the set of workers as lists, we will
|
||||
concatenate the lists of workers, starting with the list of workers from
|
||||
the first ClusterResolver passed into the constructor.
|
||||
|
||||
If *any* of the ClusterSpecs expose the set of workers as a dict, we will
|
||||
treat all the sets of workers as dicts (even if they are returned as lists)
|
||||
and will only merge them into a dict if there are no conflicting keys. If
|
||||
there is a conflicting key, we will raise a `KeyError`.
|
||||
"""
|
||||
|
||||
merged_cluster = {}
|
||||
|
||||
# We figure out whether it is all lists for a particular job, or whether
|
||||
# there are dicts inside.
|
||||
for cluster_resolver in self._cluster_resolvers:
|
||||
cluster_spec = cluster_resolver.cluster_spec()
|
||||
cluster_dict = cluster_spec.as_dict()
|
||||
|
||||
for job_name, tasks in cluster_dict.items():
|
||||
if job_name in merged_cluster:
|
||||
# If we see a dict, then we write a dict out regardless.
|
||||
if isinstance(tasks, dict):
|
||||
merged_cluster[job_name] = {}
|
||||
else:
|
||||
# We take whichever type is present.
|
||||
if isinstance(tasks, list):
|
||||
merged_cluster[job_name] = []
|
||||
else:
|
||||
merged_cluster[job_name] = {}
|
||||
|
||||
# We then do the merge as appropriate in merged_cluster[job].
|
||||
for cluster_resolver in self._cluster_resolvers:
|
||||
cluster_spec = cluster_resolver.cluster_spec()
|
||||
cluster_dict = cluster_spec.as_dict()
|
||||
|
||||
for job_name, tasks in cluster_dict.items():
|
||||
if isinstance(merged_cluster[job_name], list):
|
||||
# We all have lists, we can just concatenate and be done.
|
||||
merged_cluster[job_name].extend(tasks)
|
||||
else:
|
||||
if isinstance(tasks, list):
|
||||
# We convert to a dictionary if the type is a list.
|
||||
task_dict = dict(zip(range(0, len(tasks)), tasks))
|
||||
else:
|
||||
# We can simply make a copy (for update) and be done.
|
||||
task_dict = tasks.copy()
|
||||
|
||||
# We detect if there are duplicates, and raise an error if so.
|
||||
task_keys = set(task_dict)
|
||||
merged_keys = set(merged_cluster[job_name].keys())
|
||||
intersected_keys = task_keys.intersection(merged_keys)
|
||||
if intersected_keys:
|
||||
raise KeyError('Duplicate keys detected when merging two '
|
||||
'ClusterSpecs: %s' % repr(intersected_keys))
|
||||
|
||||
# We do the merge after all the processing.
|
||||
merged_cluster[job_name].update(task_dict)
|
||||
|
||||
return ClusterSpec(merged_cluster)
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Returns the master address to use when creating a session.
|
||||
|
||||
This usually returns the master from the first ClusterResolver passed in,
|
||||
but you can override this by specifying the task_type and task_index.
|
||||
|
||||
Args:
|
||||
task_type: (Optional) The type of the TensorFlow task of the master.
|
||||
task_index: (Optional) The index of the TensorFlow task of the master.
|
||||
rpc_layer: (Optional) The RPC protocol for the given cluster.
|
||||
|
||||
Returns:
|
||||
The name or URL of the session master.
|
||||
"""
|
||||
if task_type is not None and task_index is not None:
|
||||
master = self.cluster_spec().task_address(task_type, task_index)
|
||||
return format_master_url(master, rpc_layer or self._rpc_layer)
|
||||
|
||||
return self._cluster_resolvers[0].master(rpc_layer=rpc_layer)
|
||||
|
||||
@property
|
||||
def task_type(self):
|
||||
return self._task_type or self._cluster_resolvers[0].task_type
|
||||
|
||||
@property
|
||||
def task_index(self):
|
||||
return self._task_index or self._cluster_resolvers[0].task_index
|
||||
|
||||
@task_type.setter
|
||||
def task_type(self, task_type):
|
||||
self._task_type = task_type
|
||||
|
||||
@task_index.setter
|
||||
def task_index(self, task_index):
|
||||
self._task_index = task_index
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
return self._cluster_resolvers[0].environment
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
return self._cluster_resolvers[0].num_accelerators_per_worker(
|
||||
session_config)
|
||||
|
||||
@property
|
||||
def rpc_layer(self):
|
||||
return self._rpc_layer or self._cluster_resolvers[0].rpc_layer
|
||||
|
||||
@rpc_layer.setter
|
||||
def rpc_layer(self, rpc_layer):
|
||||
self._rpc_layer = rpc_layer
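A sketch of the list-merge behaviour documented in cluster_spec above (hosts are placeholders):

ps = SimpleClusterResolver(ClusterSpec({'ps': ['10.0.0.1:2222']}))
workers = SimpleClusterResolver(
    ClusterSpec({'worker': ['10.0.0.2:2222', '10.0.0.3:2222']}))
union = UnionClusterResolver(ps, workers, rpc_layer='grpc')
union.cluster_spec().as_dict()
# {'ps': ['10.0.0.1:2222'], 'worker': ['10.0.0.2:2222', '10.0.0.3:2222']}
# Overlapping dict-style task indices for the same job would raise KeyError.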
|
@ -18,8 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver import UnionClusterResolver
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
@ -0,0 +1,206 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for GCE Instance Groups."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
|
||||
_GOOGLE_API_CLIENT_INSTALLED = True
|
||||
try:
|
||||
from googleapiclient import discovery # pylint: disable=g-import-not-at-top
|
||||
from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
_GOOGLE_API_CLIENT_INSTALLED = False
|
||||
|
||||
|
||||
def _format_master_url(master, rpc_layer=None):
|
||||
return '%s://%s' % (rpc_layer, master) if rpc_layer else master
|
||||
|
||||
|
||||
class GceClusterResolver(ClusterResolver):
|
||||
"""Cluster Resolver for Google Compute Engine.
|
||||
|
||||
This is an implementation of cluster resolvers for the Google Compute Engine
|
||||
instance group platform. By specifying a project, zone, and instance group,
|
||||
this will retrieve the IP addresses of all the instances within the instance
|
||||
group and return a Cluster Resolver object suitable for use with distributed
|
||||
TensorFlow.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
project,
|
||||
zone,
|
||||
instance_group,
|
||||
port,
|
||||
task_type='worker',
|
||||
task_index=0,
|
||||
rpc_layer='grpc',
|
||||
num_accelerators_per_worker=0,
|
||||
credentials='default',
|
||||
service=None):
|
||||
"""Creates a new GceClusterResolver object.
|
||||
|
||||
This takes in a few parameters and creates a GceClusterResolver object. It
|
||||
will then use these parameters to query the GCE API for the IP addresses of
|
||||
each instance in the instance group.
|
||||
|
||||
Args:
|
||||
project: Name of the GCE project.
|
||||
zone: Zone of the GCE instance group.
|
||||
instance_group: Name of the GCE instance group.
|
||||
port: Port of the listening TensorFlow server (default: 8470)
|
||||
task_type: Name of the TensorFlow job this GCE instance group of VM
|
||||
instances belong to.
|
||||
task_index: The task index for this particular VM, within the GCE
|
||||
instance group. In particular, every single instance should be assigned
|
||||
a unique ordinal index within an instance group manually so that they
|
||||
can be distinguished from each other.
|
||||
rpc_layer: The RPC layer TensorFlow should use to communicate across
|
||||
instances.
|
||||
num_accelerators_per_worker: Number of accelerators (GPUs) present per
|
||||
instance.
|
||||
credentials: GCE Credentials. If nothing is specified, this defaults to
|
||||
GoogleCredentials.get_application_default().
|
||||
service: The GCE API object returned by the googleapiclient.discovery
|
||||
function. (Default: discovery.build('compute', 'v1')). If you specify a
|
||||
custom service object, then the credentials parameter will be ignored.
|
||||
|
||||
Raises:
|
||||
ImportError: If the googleapiclient is not installed.
|
||||
"""
|
||||
self._project = project
|
||||
self._zone = zone
|
||||
self._instance_group = instance_group
|
||||
self._task_type = task_type
|
||||
self._task_index = task_index
|
||||
self._rpc_layer = rpc_layer
|
||||
self._port = port
|
||||
self._credentials = credentials
|
||||
|
||||
if credentials == 'default':
|
||||
if _GOOGLE_API_CLIENT_INSTALLED:
|
||||
self._credentials = GoogleCredentials.get_application_default()
|
||||
|
||||
if service is None:
|
||||
if not _GOOGLE_API_CLIENT_INSTALLED:
|
||||
raise ImportError('googleapiclient must be installed before using the '
|
||||
'GCE cluster resolver')
|
||||
self._service = discovery.build(
|
||||
'compute', 'v1',
|
||||
credentials=self._credentials)
|
||||
else:
|
||||
self._service = service
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a ClusterSpec object based on the latest instance group info.
|
||||
|
||||
This returns a ClusterSpec object for use based on information from the
|
||||
specified instance group. We will retrieve the information from the GCE APIs
|
||||
every time this method is called.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec containing host information retrieved from GCE.
|
||||
"""
|
||||
request_body = {'instanceState': 'RUNNING'}
|
||||
request = self._service.instanceGroups().listInstances(
|
||||
project=self._project,
|
||||
zone=self._zone,
|
||||
instanceGroups=self._instance_group,
|
||||
body=request_body,
|
||||
orderBy='name')
|
||||
|
||||
worker_list = []
|
||||
|
||||
while request is not None:
|
||||
response = request.execute()
|
||||
|
||||
items = response['items']
|
||||
for instance in items:
|
||||
instance_name = instance['instance'].split('/')[-1]
|
||||
|
||||
instance_request = self._service.instances().get(
|
||||
project=self._project,
|
||||
zone=self._zone,
|
||||
instance=instance_name)
|
||||
|
||||
if instance_request is not None:
|
||||
instance_details = instance_request.execute()
|
||||
ip_address = instance_details['networkInterfaces'][0]['networkIP']
|
||||
instance_url = '%s:%s' % (ip_address, self._port)
|
||||
worker_list.append(instance_url)
|
||||
|
||||
request = self._service.instanceGroups().listInstances_next(
|
||||
previous_request=request,
|
||||
previous_response=response)
|
||||
|
||||
worker_list.sort()
|
||||
return ClusterSpec({self._task_type: worker_list})
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
task_type = task_type if task_type is not None else self._task_type
|
||||
task_index = task_index if task_index is not None else self._task_index
|
||||
|
||||
if task_type is not None and task_index is not None:
|
||||
master = self.cluster_spec().task_address(task_type, task_index)
|
||||
if rpc_layer or self._rpc_layer:
|
||||
return '%s://%s' % (rpc_layer or self._rpc_layer, master)
|
||||
else:
|
||||
return master
|
||||
|
||||
return ''
|
||||
|
||||
@property
|
||||
def task_type(self):
|
||||
return self._task_type
|
||||
|
||||
@property
|
||||
def task_index(self):
|
||||
return self._task_index
|
||||
|
||||
@task_type.setter
|
||||
def task_type(self, task_type):
|
||||
raise RuntimeError(
|
||||
'You cannot reset the task_type of the GceClusterResolver after it has '
|
||||
'been created.')
|
||||
|
||||
@task_index.setter
|
||||
def task_index(self, task_index):
|
||||
self._task_index = task_index
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in.
|
||||
|
||||
For users in the GCE environment, the environment property is always an
|
||||
empty string, and Google users will not use this ClusterResolver for running
|
||||
on internal systems.
|
||||
"""
|
||||
return ''
|
||||
|
||||
@property
|
||||
def rpc_layer(self):
|
||||
return self._rpc_layer
|
||||
|
||||
@rpc_layer.setter
|
||||
def rpc_layer(self, rpc_layer):
|
||||
self._rpc_layer = rpc_layer
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
del session_config # Unused, since this is set manually in __init__.
|
||||
return self._num_accelerators_per_worker
|
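# Illustrative usage sketch, not part of this change. The constructor arguments
# below (project, zone, instance_group, port) mirror the self._project,
# self._zone, self._instance_group and self._port attributes used above; the
# values are placeholders, and a reachable GCE API with valid googleapiclient
# credentials is assumed.
if __name__ == '__main__':
  resolver = GceClusterResolver(
      project='my-project',          # placeholder GCP project
      zone='us-central1-a',          # placeholder zone
      instance_group='tf-workers',   # placeholder instance group name
      port=8470)
  print(resolver.cluster_spec().as_dict())  # e.g. {'worker': ['10.0.0.2:8470']}
  # The default task_type is assumed to be 'worker'; pass it explicitly here.
  print(resolver.master(task_type='worker', task_index=0, rpc_layer='grpc'))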
@ -18,8 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver import GceClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver import UnionClusterResolver
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
@ -0,0 +1,173 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for Kubernetes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
_KUBERNETES_API_CLIENT_INSTALLED = True
|
||||
try:
|
||||
from kubernetes import client as k8sclient # pylint: disable=g-import-not-at-top
|
||||
from kubernetes import config as k8sconfig # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
_KUBERNETES_API_CLIENT_INSTALLED = False
|
||||
|
||||
|
||||
class KubernetesClusterResolver(ClusterResolver):
|
||||
"""Cluster Resolver for Kubernetes.
|
||||
|
||||
This is an implementation of cluster resolvers for Kubernetes. When given
the Kubernetes namespace and label selector for pods, we will retrieve the
|
||||
pod IP addresses of all running pods matching the selector, and return a
|
||||
ClusterSpec based on that information.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
job_to_label_mapping=None,
|
||||
tf_server_port=8470,
|
||||
rpc_layer='grpc',
|
||||
override_client=None):
|
||||
"""Initializes a new KubernetesClusterResolver.
|
||||
|
||||
This initializes a new Kubernetes Cluster Resolver. The Cluster Resolver
|
||||
will attempt to talk to the Kubernetes master to retrieve all the instances
|
||||
of pods matching a label selector.
|
||||
|
||||
Args:
|
||||
job_to_label_mapping: A mapping of TensorFlow jobs to label selectors.
|
||||
This allows users to specify many TensorFlow jobs in one Cluster
|
||||
Resolver, and each job can have pods belonging to different label
|
||||
selectors. For example, a sample mapping might be
|
||||
```
|
||||
{'worker': ['job-name=worker-cluster-a', 'job-name=worker-cluster-b'],
|
||||
'ps': ['job-name=ps-1', 'job-name=ps-2']}
|
||||
```
|
||||
tf_server_port: The port the TensorFlow server is listening on.
|
||||
rpc_layer: (Optional) The RPC layer TensorFlow should use to communicate
|
||||
between tasks in Kubernetes. Defaults to 'grpc'.
|
||||
override_client: The Kubernetes client (usually automatically retrieved
|
||||
using `from kubernetes import client as k8sclient`). If you pass this
|
||||
in, you are responsible for setting Kubernetes credentials manually.
|
||||
|
||||
Raises:
|
||||
ImportError: If the Kubernetes Python client is not installed and no
|
||||
`override_client` is passed in.
|
||||
"""
|
||||
if _KUBERNETES_API_CLIENT_INSTALLED:
|
||||
k8sconfig.load_kube_config()
|
||||
|
||||
if not job_to_label_mapping:
|
||||
job_to_label_mapping = {'worker': ['job-name=tensorflow']}
|
||||
|
||||
if not override_client and not _KUBERNETES_API_CLIENT_INSTALLED:
|
||||
raise ImportError('The Kubernetes Python client must be installed before '
'using the Kubernetes Cluster Resolver. To install the '
|
||||
'Kubernetes Python client, run `pip install '
|
||||
'kubernetes` on your command line.')
|
||||
|
||||
self._job_to_label_mapping = job_to_label_mapping
|
||||
self._tf_server_port = tf_server_port
|
||||
self._override_client = override_client
|
||||
|
||||
self.task_type = None
|
||||
self.task_index = None
|
||||
self.rpc_layer = rpc_layer
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Returns the master address to use when creating a session.
|
||||
|
||||
You must have set the task_type and task_index object properties before
|
||||
calling this function, or pass in the `task_type` and `task_index`
|
||||
parameters when using this function. If you do both, the function parameters
|
||||
will override the object properties.
|
||||
|
||||
Args:
|
||||
task_type: (Optional) The type of the TensorFlow task of the master.
|
||||
task_index: (Optional) The index of the TensorFlow task of the master.
|
||||
rpc_layer: (Optional) The RPC protocol for the given cluster.
|
||||
|
||||
Returns:
|
||||
The name or URL of the session master.
|
||||
"""
|
||||
if task_type is not None and task_index is not None:
|
||||
return format_master_url(
|
||||
self.cluster_spec().task_address(task_type, task_index),
|
||||
rpc_layer or self.rpc_layer)
|
||||
|
||||
if self.task_type is not None and self.task_index is not None:
|
||||
return format_master_url(
|
||||
self.cluster_spec().task_address(self.task_type, self.task_index),
|
||||
rpc_layer or self.rpc_layer)
|
||||
|
||||
return ''
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a ClusterSpec object based on the latest info from Kubernetes.
|
||||
|
||||
We retrieve the information from the Kubernetes master every time this
|
||||
method is called.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec containing host information returned from Kubernetes.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If any of the pods returned by the master is not in the
|
||||
`Running` phase.
|
||||
"""
|
||||
if not self._override_client:
|
||||
k8sconfig.load_kube_config()
|
||||
|
||||
client = self._override_client or k8sclient.CoreV1Api()
|
||||
cluster_map = {}
|
||||
|
||||
for tf_job in self._job_to_label_mapping:
|
||||
all_pods = []
|
||||
for selector in self._job_to_label_mapping[tf_job]:
|
||||
ret = client.list_pod_for_all_namespaces(label_selector=selector)
|
||||
selected_pods = []
|
||||
|
||||
# Sort the list by name to make sure it doesn't change from call to call.
|
||||
for pod in sorted(ret.items, key=lambda x: x.metadata.name):
|
||||
if pod.status.phase == 'Running':
|
||||
selected_pods.append(
|
||||
'%s:%s' % (pod.status.host_ip, self._tf_server_port))
|
||||
else:
|
||||
raise RuntimeError('Pod "%s" is not running; phase: "%s"' %
|
||||
(pod.metadata.name, pod.status.phase))
|
||||
all_pods.extend(selected_pods)
|
||||
cluster_map[tf_job] = all_pods
|
||||
|
||||
return server_lib.ClusterSpec(cluster_map)
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in.
|
||||
|
||||
For users in the Cloud environment, the environment property is always an
|
||||
empty string, and Google users will not use this ClusterResolver for running
|
||||
on internal systems.
|
||||
"""
|
||||
return ''
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
local_devices = device_lib.list_local_devices(session_config)
|
||||
return len([d for d in local_devices if d.device_type == 'GPU'])
|
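# Illustrative usage sketch, not part of this change. Assumes a reachable
# Kubernetes cluster with running pods labelled 'job-name=tensorflow' and the
# `kubernetes` Python client installed and configured.
if __name__ == '__main__':
  resolver = KubernetesClusterResolver(
      job_to_label_mapping={'worker': ['job-name=tensorflow']},
      tf_server_port=8470)
  print(resolver.cluster_spec().as_dict())  # e.g. {'worker': ['10.1.2.3:8470']}
  print(resolver.master(task_type='worker', task_index=0))  # 'grpc://10.1.2.3:8470'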
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training import KubernetesClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver import KubernetesClusterResolver
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
@ -0,0 +1,226 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for Slurm workload manager."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
|
||||
|
||||
class SlurmClusterResolver(ClusterResolver):
|
||||
"""Cluster Resolver for system with Slurm workload manager.
|
||||
|
||||
This is an implementation of cluster resolvers for Slurm clusters. It allows
the specification of jobs and task counts, the number of tasks per node, the
number of GPUs on each node and the number of GPUs for each task. It retrieves
system attributes from Slurm environment variables, resolves allocated compute
node names, constructs a cluster and returns a ClusterResolver object which
can be used for distributed TensorFlow.
|
||||
"""
|
||||
|
||||
def _resolve_hostnames(self):
|
||||
"""Resolve host names of nodes allocated in current jobs.
|
||||
|
||||
Returns:
|
||||
A list of node names as strings.
|
||||
"""
|
||||
hostlist = (subprocess.check_output(['scontrol', 'show', 'hostname']).
|
||||
decode('utf-8').strip().split('\n'))
|
||||
return hostlist
|
||||
|
||||
def __init__(self,
|
||||
jobs,
|
||||
port_base=8888,
|
||||
gpus_per_node=1,
|
||||
gpus_per_task=1,
|
||||
tasks_per_node=None,
|
||||
auto_set_gpu=True,
|
||||
rpc_layer='grpc'):
|
||||
"""Creates a new SlurmClusterResolver object.
|
||||
|
||||
This takes in parameters and creates a SlurmClusterResolver object. It uses
those parameters to determine on which nodes the processes will reside and
resolves their hostnames. Based on the number of GPUs on each node and the
number of GPUs for each task, it offsets the port number for each process and
allocates GPUs to tasks by setting environment variables. The resolver
currently supports homogeneous tasks and default Slurm process allocation.
|
||||
|
||||
Args:
|
||||
jobs: Dictionary with job names as key and number of tasks in the job as
|
||||
value
|
||||
port_base: The first port number to start with for processes on a node.
|
||||
gpus_per_node: Number of GPUs available on each node.
|
||||
gpus_per_task: Number of GPUs to be used for each task.
|
||||
tasks_per_node: Number of tasks to run on each node, if not set defaults
|
||||
to Slurm's output environment variable SLURM_NTASKS_PER_NODE.
|
||||
auto_set_gpu: Set the visible CUDA devices automatically while resolving
|
||||
the cluster by setting CUDA_VISIBLE_DEVICES environment variable.
|
||||
Defaults to True.
|
||||
rpc_layer: (Optional) The protocol TensorFlow uses to communicate between
|
||||
nodes. Defaults to 'grpc'.
|
||||
|
||||
Returns:
|
||||
A ClusterResolver object which can be used with distributed TensorFlow.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If more GPUs per node are requested than are available, or
if more tasks are requested than the number of assigned tasks.
|
||||
"""
|
||||
|
||||
# check if launched by mpirun
|
||||
if 'OMPI_COMM_WORLD_RANK' in os.environ:
|
||||
self._rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
||||
num_tasks = int(os.environ['OMPI_COMM_WORLD_SIZE'])
|
||||
else:
|
||||
self._rank = int(os.environ['SLURM_PROCID'])
|
||||
num_tasks = int(os.environ['SLURM_NTASKS'])
|
||||
|
||||
self._jobs = collections.OrderedDict(sorted(jobs.items()))
|
||||
self._port_base = port_base
|
||||
|
||||
# user specification overrides SLURM specification
|
||||
if tasks_per_node is not None:
|
||||
self._tasks_per_node = tasks_per_node
|
||||
elif tasks_per_node is None and 'SLURM_NTASKS_PER_NODE' in os.environ:
|
||||
self._tasks_per_node = int(os.environ['SLURM_NTASKS_PER_NODE'])
|
||||
else:
|
||||
raise RuntimeError('Neither `tasks_per_node` nor '
|
||||
'SLURM_NTASKS_PER_NODE is set.')
|
||||
|
||||
self._gpus_per_node = gpus_per_node
|
||||
self._gpus_per_task = gpus_per_task
|
||||
|
||||
self._auto_set_gpu = auto_set_gpu
|
||||
self.task_type = None
|
||||
self.task_index = None
|
||||
self.rpc_layer = rpc_layer
|
||||
|
||||
self._gpu_allocation = []
|
||||
self._cluster_allocation = {}
|
||||
|
||||
if self._tasks_per_node * self._gpus_per_task > self._gpus_per_node:
|
||||
raise RuntimeError('Requested more GPUs per node than available.')
|
||||
|
||||
if sum(self._jobs.values()) != num_tasks:
|
||||
raise RuntimeError('Requested more tasks than assigned tasks.')
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a ClusterSpec object based on the latest instance group info.
|
||||
|
||||
This returns a ClusterSpec object for use based on information from the
|
||||
specified initialization parameters and Slurm environment variables. The
|
||||
cluster specification is resolved each time this function is called. The
|
||||
resolver extract hostnames of nodes by scontrol and pack tasks in that
|
||||
order until a node a has number of tasks that is equal to specification.
|
||||
GPUs on nodes are allocated to tasks by specification through setting
|
||||
CUDA_VISIBLE_DEVICES environment variable.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec containing host information retrieved from Slurm's
|
||||
environment variables.
|
||||
"""
|
||||
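# Worked example (illustrative, not part of this change): with
# tasks_per_node=2, gpus_per_node=2, gpus_per_task=1, port_base=8888 and
# hosts ['node0', 'node1'], the loop below yields
#   task_list       = ['node0:8888', 'node0:8889', 'node1:8888', 'node1:8889']
#   _gpu_allocation = ['0', '1', '0', '1']
# and jobs={'ps': 1, 'worker': 3} then maps 'ps' to task_list[0:1] and
# 'worker' to task_list[1:4].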
hostlist = self._resolve_hostnames()
|
||||
|
||||
task_list = []
|
||||
self._gpu_allocation = []
|
||||
self._cluster_allocation = {}
|
||||
|
||||
for host in hostlist:
|
||||
for port_offset, gpu_offset in zip(
|
||||
range(self._tasks_per_node),
|
||||
range(0, self._gpus_per_node, self._gpus_per_task)):
|
||||
|
||||
host_addr = '%s:%d' % (host, self._port_base + port_offset)
|
||||
task_list.append(host_addr)
|
||||
gpu_id_list = []
|
||||
|
||||
for gpu_id in range(gpu_offset, gpu_offset + self._gpus_per_task):
|
||||
gpu_id_list.append(str(gpu_id))
|
||||
|
||||
self._gpu_allocation.append(','.join(gpu_id_list))
|
||||
|
||||
cluster_rank_offset_start = 0
|
||||
cluster_rank_offset_end = 0
|
||||
|
||||
for task_type, num_tasks in self._jobs.items():
|
||||
cluster_rank_offset_end = cluster_rank_offset_start + num_tasks
|
||||
|
||||
self._cluster_allocation[task_type] = (
|
||||
task_list[cluster_rank_offset_start:cluster_rank_offset_end])
|
||||
|
||||
if cluster_rank_offset_start <= self._rank < cluster_rank_offset_end:
|
||||
self.task_type = task_type
|
||||
self.task_index = self._rank - cluster_rank_offset_start
|
||||
|
||||
cluster_rank_offset_start = cluster_rank_offset_end
|
||||
|
||||
if self._auto_set_gpu is True:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = self._gpu_allocation[self._rank]
|
||||
|
||||
return ClusterSpec(self._cluster_allocation)
|
||||
|
||||
def get_task_info(self):
|
||||
"""Returns job name and task_index for the process which calls this.
|
||||
|
||||
This returns the job name and task index for the process which calls this
|
||||
function according to its rank and cluster specification. The job name and
|
||||
task index are set after a cluster is constructed by cluster_spec otherwise
|
||||
defaults to None.
|
||||
|
||||
Returns:
|
||||
A string specifying the job name the process belongs to and an integer
|
||||
specifying the task index the process belongs to in that job.
|
||||
"""
|
||||
return self.task_type, self.task_index
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Returns the master string for connecting to a TensorFlow master.
|
||||
|
||||
Args:
|
||||
task_type: (Optional) Overrides the default auto-selected task type.
|
||||
task_index: (Optional) Overrides the default auto-selected task index.
|
||||
rpc_layer: (Optional) Overrides the default RPC protocol TensorFlow uses
|
||||
to communicate across nodes.
|
||||
|
||||
Returns:
|
||||
A connection string for connecting to a TensorFlow master.
|
||||
"""
|
||||
task_type = task_type if task_type is not None else self.task_type
|
||||
task_index = task_index if task_index is not None else self.task_index
|
||||
rpc_layer = rpc_layer or self.rpc_layer
|
||||
master = self.cluster_spec().task_address(task_type, task_index)
|
||||
|
||||
return '%s://%s' % (rpc_layer, master) if rpc_layer else master
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in.
|
||||
|
||||
For users in the Slurm environment, the environment property is always an
|
||||
empty string, and Google users will not use this ClusterResolver for running
|
||||
on internal systems.
|
||||
"""
|
||||
return ''
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
del session_config # Unused, since this is set in __init__ manually.
|
||||
return self._gpus_per_node
|
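# Illustrative usage sketch, not part of this change. Assumes the process was
# launched under Slurm (SLURM_PROCID and SLURM_NTASKS are set, and either
# tasks_per_node is passed or SLURM_NTASKS_PER_NODE is set) and that `scontrol`
# is on PATH; the job sizes below are placeholders and must sum to SLURM_NTASKS.
if __name__ == '__main__':
  resolver = SlurmClusterResolver(
      jobs={'ps': 1, 'worker': 3},
      port_base=8888,
      gpus_per_node=2,
      gpus_per_task=1,
      tasks_per_node=2)
  cluster = resolver.cluster_spec()  # also sets CUDA_VISIBLE_DEVICES for this task
  task_type, task_index = resolver.get_task_info()
  print(task_type, task_index, resolver.master(rpc_layer='grpc'))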
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver import SlurmClusterResolver
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
@ -0,0 +1,171 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables."""
|
||||
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
|
||||
_TF_CONFIG_ENV = 'TF_CONFIG'
|
||||
_SESSION_MASTER_KEY = 'session_master'
|
||||
_RPC_LAYER_KEY = 'rpc_layer'
|
||||
_TASK_KEY = 'task'
|
||||
|
||||
|
||||
def format_master_url(master, rpc_layer=None):
|
||||
if rpc_layer:
|
||||
return '%s://%s' % (rpc_layer, master)
|
||||
else:
|
||||
return master
|
||||
|
||||
|
||||
def _load_tf_config():
|
||||
return json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
|
||||
|
||||
|
||||
def _get_value_in_tfconfig(key, default=None):
|
||||
tf_config = _load_tf_config()
|
||||
return tf_config[key] if key in tf_config else default
|
||||
|
||||
|
||||
class TFConfigClusterResolver(ClusterResolver):
|
||||
"""Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar."""
|
||||
|
||||
def __init__(self,
|
||||
task_type=None,
|
||||
task_index=None,
|
||||
rpc_layer=None,
|
||||
environment=None,
|
||||
num_accelerators_per_worker=0):
|
||||
"""Creates a new TFConfigClusterResolver.
|
||||
|
||||
Args:
|
||||
task_type: (String, optional) Overrides the task type specified in the
|
||||
TF_CONFIG environment variable.
|
||||
task_index: (Integer, optional) Overrides the task index specified in the
|
||||
TF_CONFIG environment variable.
|
||||
rpc_layer: (String, optional) Overrides the rpc layer TensorFlow uses.
|
||||
environment: (String, optional) Overrides the environment TensorFlow
|
||||
operates in.
|
||||
num_accelerators_per_worker: (Integer, optional) Specifies the number of
|
||||
accelerators (e.g. GPUs, TPUs, others) that each node has.
|
||||
"""
|
||||
|
||||
self._task_type = task_type
|
||||
self._task_index = task_index
|
||||
self._rpc_layer = rpc_layer
|
||||
self._environment = environment
|
||||
self._num_accelerators_per_worker = num_accelerators_per_worker
|
||||
|
||||
@property
|
||||
def task_type(self):
|
||||
if self._task_type is None:
|
||||
task_info = _get_value_in_tfconfig(_TASK_KEY, {})
|
||||
return task_info['type'] if 'type' in task_info else None
|
||||
else:
|
||||
return self._task_type
|
||||
|
||||
@property
|
||||
def task_index(self):
|
||||
if self._task_index is None:
|
||||
task_info = _get_value_in_tfconfig(_TASK_KEY, {})
|
||||
return task_info['index'] if 'index' in task_info else None
|
||||
else:
|
||||
return self._task_index
|
||||
|
||||
@task_type.setter
|
||||
def task_type(self, task_type):
|
||||
self._task_type = task_type
|
||||
|
||||
@task_index.setter
|
||||
def task_index(self, task_index):
|
||||
self._task_index = task_index
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
return self._environment
|
||||
|
||||
@property
|
||||
def rpc_layer(self):
|
||||
if self._rpc_layer is None:
|
||||
return _get_value_in_tfconfig(_RPC_LAYER_KEY)
|
||||
else:
|
||||
return self._rpc_layer
|
||||
|
||||
@rpc_layer.setter
|
||||
def rpc_layer(self, rpc_layer):
|
||||
self._rpc_layer = rpc_layer
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
# TODO(frankchn): Connect to server (w/ session_config) in the future.
|
||||
del session_config # Unused, we do not connect to another server here.
|
||||
return self._num_accelerators_per_worker
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a ClusterSpec based on the TF_CONFIG environment variable.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec with information from the TF_CONFIG environment variable.
|
||||
"""
|
||||
tf_config = _load_tf_config()
|
||||
if 'cluster' not in tf_config:
|
||||
return ClusterSpec({})
|
||||
return ClusterSpec(tf_config['cluster'])
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Returns the master address to use when creating a TensorFlow session.
|
||||
|
||||
Args:
|
||||
task_type: (String, optional) Overrides and sets the task_type of the
|
||||
master.
|
||||
task_index: (Integer, optional) Overrides and sets the task id of the
|
||||
master.
|
||||
rpc_layer: (String, optional) Overrides and sets the protocol over which
|
||||
TensorFlow nodes communicate with each other.
|
||||
|
||||
Returns:
|
||||
The address of the master.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the task_type or task_index is not specified and the
|
||||
`TF_CONFIG` environment variable does not contain a task section.
|
||||
"""
|
||||
|
||||
# If `session_master` is set, just use that.
|
||||
session_master = _get_value_in_tfconfig(_SESSION_MASTER_KEY)
|
||||
if session_master is not None:
|
||||
return session_master
|
||||
|
||||
# Return an empty string if we are the only job in the ClusterSpec.
|
||||
cluster_spec = self.cluster_spec()
|
||||
if (not cluster_spec.jobs or
|
||||
(len(cluster_spec.jobs) == 1 and
|
||||
len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)):
|
||||
return ''
|
||||
|
||||
# We try to auto-detect the task type and index, but use the user-supplied
# ones where available.
|
||||
task_type = task_type if task_type is not None else self.task_type
|
||||
task_index = task_index if task_index is not None else self.task_index
|
||||
|
||||
return format_master_url(cluster_spec.task_address(task_type, task_index),
|
||||
self.rpc_layer)
|
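# Illustrative usage sketch, not part of this change. TF_CONFIG is re-read every
# time cluster_spec()/master() are called, so it can simply be set beforehand;
# the addresses below are placeholders.
if __name__ == '__main__':
  os.environ['TF_CONFIG'] = json.dumps({
      'cluster': {'worker': ['10.1.2.3:8470', '10.1.2.4:8470']},
      'task': {'type': 'worker', 'index': 1},
      'rpc_layer': 'grpc',
  })
  resolver = TFConfigClusterResolver()
  print(resolver.task_type, resolver.task_index)  # worker 1
  print(resolver.master())                        # 'grpc://10.1.2.4:8470'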
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
@ -0,0 +1,423 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for Cloud TPUs."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from six.moves.urllib.request import Request
|
||||
from six.moves.urllib.request import urlopen
|
||||
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_GOOGLE_API_CLIENT_INSTALLED = True
|
||||
try:
|
||||
from googleapiclient import discovery # pylint: disable=g-import-not-at-top
|
||||
from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
_GOOGLE_API_CLIENT_INSTALLED = False
|
||||
|
||||
|
||||
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
|
||||
_ENDPOINTS_SEPARATOR = ','
|
||||
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
|
||||
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
|
||||
|
||||
|
||||
class TPUClusterResolver(ClusterResolver):
|
||||
"""Cluster Resolver for Google Cloud TPUs.
|
||||
|
||||
This is an implementation of cluster resolvers for the Google Cloud TPU
|
||||
service. As Cloud TPUs are in alpha, you will need to specify an API definition
|
||||
file for this to consume, in addition to a list of Cloud TPUs in your Google
|
||||
Cloud Platform project.
|
||||
"""
|
||||
|
||||
def _tpuService(self):
|
||||
"""Creates a new Cloud TPU API object.
|
||||
|
||||
This works around an issue where the underlying HTTP connection sometimes
|
||||
times out when the script has been running for too long. Other methods in
|
||||
this object call this method to get a new API object whenever they need
|
||||
to communicate with the Cloud API.
|
||||
|
||||
Returns:
|
||||
A Google Cloud TPU API object.
|
||||
"""
|
||||
if self._service:
|
||||
return self._service
|
||||
|
||||
credentials = self._credentials
|
||||
if credentials is None or credentials == 'default':
|
||||
credentials = GoogleCredentials.get_application_default()
|
||||
|
||||
if self._discovery_url:
|
||||
return discovery.build(
|
||||
'tpu', 'v1alpha1',
|
||||
credentials=credentials,
|
||||
discoveryServiceUrl=self._discovery_url)
|
||||
else:
|
||||
return discovery.build(
|
||||
'tpu', 'v1alpha1',
|
||||
credentials=credentials)
|
||||
|
||||
def _requestComputeMetadata(self, path):
|
||||
req = Request('http://metadata/computeMetadata/v1/%s' % path,
|
||||
headers={'Metadata-Flavor': 'Google'})
|
||||
resp = urlopen(req)
|
||||
return compat.as_bytes(resp.read())
|
||||
|
||||
def _shouldResolve(self):
|
||||
if isinstance(self._should_resolve_override, bool):
|
||||
return self._should_resolve_override
|
||||
if (self._tpu == compat.as_bytes('') or
|
||||
self._tpu == compat.as_bytes('local') or
|
||||
self._tpu.startswith(compat.as_bytes('/bns')) or
|
||||
self._tpu.startswith(compat.as_bytes('localhost:')) or
|
||||
self._tpu.startswith(compat.as_bytes('grpc://'))):
|
||||
return False
|
||||
return True
|
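# Examples of how _shouldResolve() treats different `tpu` values (illustrative,
# not part of this change; addresses are placeholders):
#   tpu='my-tpu'                  -> resolved through the Cloud TPU API
#   tpu='grpc://10.240.1.2:8470'  -> used directly as the session master
#   tpu='local' or tpu=''         -> TPU attached to the host, no resolution
#   tpu='/bns/...'                -> Google-internal BNS address, no resolution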
||||
|
||||
@staticmethod
|
||||
def _inGke():
|
||||
"""When running in GKE, the environment variable will be set."""
|
||||
return _GKE_ENV_VARIABLE in os.environ
|
||||
|
||||
@staticmethod
|
||||
def _gkeEndpoints():
|
||||
return os.environ[_GKE_ENV_VARIABLE]
|
||||
|
||||
@staticmethod
|
||||
def _envVarFallback():
|
||||
if _DEFAULT_ENV_VARIABLE in os.environ:
|
||||
return os.environ[_DEFAULT_ENV_VARIABLE]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _environmentDiscoveryUrl():
|
||||
return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
|
||||
|
||||
def __init__(self,
|
||||
tpu=None,
|
||||
zone=None,
|
||||
project=None,
|
||||
job_name='worker',
|
||||
coordinator_name=None,
|
||||
coordinator_address=None,
|
||||
credentials='default',
|
||||
service=None,
|
||||
discovery_url=None):
|
||||
"""Creates a new TPUClusterResolver object.
|
||||
|
||||
The ClusterResolver will then use the parameters to query the Cloud TPU APIs
|
||||
for the IP addresses and ports of each Cloud TPU listed.
|
||||
|
||||
Args:
|
||||
tpu: Either a string, or a list of strings corresponding to the TPUs to
|
||||
use. If the single string is the empty string, the string 'local', or a
|
||||
string that begins with 'grpc://' or '/bns', then it is assumed to not
|
||||
correspond with a Cloud TPU and will instead be passed as the session
|
||||
master and no ClusterSpec propagation will be done.
|
||||
zone: Zone where the TPUs are located. If omitted or empty, we will assume
|
||||
that the zone of the TPU is the same as the zone of the GCE VM, which we
|
||||
will try to discover from the GCE metadata service.
|
||||
project: Name of the GCP project containing Cloud TPUs. If omitted or
|
||||
empty, we will try to discover the project name of the GCE VM from the
|
||||
GCE metadata service.
|
||||
job_name: Name of the TensorFlow job the TPUs belong to.
|
||||
coordinator_name: The name to use for the coordinator. Set to None if the
|
||||
coordinator should not be included in the computed ClusterSpec.
|
||||
coordinator_address: The address of the coordinator (typically an ip:port
|
||||
pair). If set to None, a TF server will be started. If coordinator_name
|
||||
is None, a TF server will not be started even if coordinator_address is
|
||||
None.
|
||||
credentials: GCE Credentials. If None, then we use default credentials
|
||||
from the oauth2client library.
|
||||
service: The GCE API object returned by the googleapiclient.discovery
|
||||
function. If you specify a custom service object, then the credentials
|
||||
parameter will be ignored.
|
||||
discovery_url: A URL template that points to the location of
|
||||
the discovery service. It should have two parameters {api} and
|
||||
{apiVersion} that when filled in produce an absolute URL to the
|
||||
discovery document for that service. The environment variable
|
||||
'TPU_API_DISCOVERY_URL' will override this.
|
||||
|
||||
Raises:
|
||||
ImportError: If the googleapiclient is not installed.
|
||||
ValueError: If no TPUs are specified.
|
||||
"""
|
||||
if isinstance(tpu, list):
|
||||
if not tpu:
|
||||
raise ValueError('At least one TPU must be specified.')
|
||||
if len(tpu) != 1:
|
||||
raise NotImplementedError(
|
||||
'Using multiple TPUs in a single session is not yet implemented')
|
||||
tpu = tpu[0]
|
||||
|
||||
in_gke = self._inGke()
|
||||
# When using GKE with Cloud TPUs, the env variable will be set.
|
||||
if tpu is None:
|
||||
if in_gke:
|
||||
tpu = self._gkeEndpoints()
|
||||
else:
|
||||
tpu = self._envVarFallback()
|
||||
|
||||
if tpu is None:
|
||||
raise ValueError('Please provide a TPU Name to connect to.')
|
||||
|
||||
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
|
||||
|
||||
# By default the task_type is 'worker' and the task_index is 0 (which is the
# first worker in the job).
|
||||
self.task_type = job_name
|
||||
self.task_index = 0
|
||||
|
||||
if tpu.startswith('grpc://'):
|
||||
# Cloud environment, where we are using GRPC to communicate to TPUs.
|
||||
self._environment = ''
|
||||
elif tpu == 'local' or not tpu:
|
||||
# Google environment, where the TPU is attached to the host.
|
||||
self._environment = 'google'
|
||||
elif tpu.startswith('/bns'):
|
||||
# Google environment, where we reach the TPU through BNS.
|
||||
self._environment = 'google'
|
||||
|
||||
# If TPU is in the Google environment or exists locally, we don't use any
|
||||
# RPC layer.
|
||||
if tpu.startswith('/bns') or tpu == 'local' or not tpu:
|
||||
self.rpc_layer = None
|
||||
else:
|
||||
self.rpc_layer = 'grpc'
|
||||
|
||||
# Setting this overrides the return value of self._shouldResolve()
|
||||
self._should_resolve_override = None
|
||||
|
||||
# We strip out the protocol if it is included, and override the
|
||||
# shouldResolve function to never resolve. We are adding the protocol back
|
||||
# in later in self.master().
|
||||
if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'):
|
||||
tpu = tpu[len(self.rpc_layer + '://'):]
|
||||
self._tpu = tpu
|
||||
self._should_resolve_override = False
|
||||
|
||||
# Whether we should actually attempt to contact Cloud APIs
|
||||
should_resolve = self._shouldResolve()
|
||||
|
||||
# We error out if we are in a non-Cloud environment which cannot talk to the
|
||||
# Cloud APIs using the standard class and a special object is not passed in.
|
||||
self._service = service
|
||||
if (self._service is None and should_resolve and
|
||||
not _GOOGLE_API_CLIENT_INSTALLED):
|
||||
raise ImportError('googleapiclient and oauth2client must be installed '
|
||||
'before using the TPU cluster resolver. Execute: '
|
||||
'`pip install --upgrade google-api-python-client` '
|
||||
'and `pip install --upgrade oauth2client` to '
|
||||
'install with pip.')
|
||||
|
||||
# We save user-passed credentials, unless the user didn't pass in anything.
|
||||
self._credentials = credentials
|
||||
if (credentials == 'default' and should_resolve and
|
||||
_GOOGLE_API_CLIENT_INSTALLED):
|
||||
self._credentials = None
|
||||
|
||||
# Automatically detect project and zone if unspecified.
|
||||
if not project and should_resolve:
|
||||
project = compat.as_str(
|
||||
self._requestComputeMetadata('project/project-id'))
|
||||
if not zone and should_resolve:
|
||||
zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
|
||||
zone = zone_path.split('/')[-1]
|
||||
self._project = project
|
||||
self._zone = zone
|
||||
|
||||
self._discovery_url = self._environmentDiscoveryUrl() or discovery_url
|
||||
|
||||
self._coordinator_name = coordinator_name
|
||||
if (coordinator_name and not coordinator_address and
|
||||
(should_resolve or in_gke)):
|
||||
self._start_local_server()
|
||||
else:
|
||||
self._coordinator_address = coordinator_address
|
||||
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
"""Get the Master string to be used for the session.
|
||||
|
||||
In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
|
||||
the first instance in the ClusterSpec returned by the cluster_spec function.
|
||||
|
||||
If a non-TPU name is used when constructing a TPUClusterResolver, that will
|
||||
be returned instead (e.g. if the tpu argument's value when constructing
|
||||
this TPUClusterResolver was 'grpc://10.240.1.2:8470',
|
||||
'grpc://10.240.1.2:8470' will be returned).
|
||||
|
||||
Args:
|
||||
task_type: (Optional, string) The type of the TensorFlow task of the
|
||||
master.
|
||||
task_index: (Optional, integer) The index of the TensorFlow task of the
|
||||
master.
|
||||
rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
|
||||
communicate with TPUs.
|
||||
|
||||
Returns:
|
||||
string, the connection string to use when creating a session.
|
||||
|
||||
Raises:
|
||||
ValueError: If none of the TPUs specified exists.
|
||||
"""
|
||||
if self._shouldResolve():
|
||||
# We are going to communicate with the Cloud TPU APIs to get a Cluster.
|
||||
cluster_spec = self.cluster_spec()
|
||||
if task_type is not None and task_index is not None:
|
||||
# task_type and task_index are from the function parameters
|
||||
master = cluster_spec.task_address(task_type, task_index)
|
||||
elif self.task_type is not None and self.task_index is not None:
|
||||
# task_type and task_index are from the object
|
||||
master = cluster_spec.task_address(self.task_type, self.task_index)
|
||||
else:
|
||||
# by default we take the first item in the cluster with the right name
|
||||
job_tasks = cluster_spec.job_tasks(self.task_type)
|
||||
if not job_tasks:
|
||||
raise ValueError('No TPUs with the specified names exist.')
|
||||
master = job_tasks[0]
|
||||
else:
|
||||
if isinstance(self._tpu, (bytes, bytearray)):
|
||||
master = self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0]
|
||||
else:
|
||||
master = self._tpu.split(_ENDPOINTS_SEPARATOR)[0]
|
||||
return format_master_url(master, rpc_layer or self.rpc_layer)
|
||||
|
||||
def get_master(self):
|
||||
return self.master()
|
||||
|
||||
def get_job_name(self):
|
||||
if self._shouldResolve():
|
||||
return self.task_type
|
||||
|
||||
def cluster_spec(self):
|
||||
"""Returns a ClusterSpec object based on the latest TPU information.
|
||||
|
||||
We retrieve the information from the Cloud TPU APIs every time this method is
|
||||
called.
|
||||
|
||||
Returns:
|
||||
A ClusterSpec containing host information returned from Cloud TPUs.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the provided TPU is not healthy.
|
||||
"""
|
||||
############################################################################
|
||||
# There are 5 potential cases this code must handle:
|
||||
# 1. [Normal case.] We should resolve the TPU name to a set of tasks, and
|
||||
# a. Create a ClusterSpec that includes the coordinator job
|
||||
# b. Create a ClusterSpec without the coordinator job.
|
||||
# 2. [GKE / No API Access.] We should not resolve the TPU name to a set of
|
||||
# tasks and
|
||||
# a. Create a ClusterSpec with the coordinator
|
||||
# b. Create a ClusterSpec without the coordinator
|
||||
# 3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
|
||||
############################################################################
|
||||
|
||||
if self._shouldResolve():
|
||||
# Case 1.
|
||||
full_name = 'projects/%s/locations/%s/nodes/%s' % (
|
||||
self._project, self._zone, compat.as_text(self._tpu))
|
||||
service = self._tpuService()
|
||||
request = service.projects().locations().nodes().get(name=full_name)
|
||||
response = request.execute()
|
||||
|
||||
if 'state' in response and response['state'] != 'READY':
|
||||
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
|
||||
(compat.as_text(self._tpu), response['state']))
|
||||
|
||||
if 'health' in response and response['health'] != 'HEALTHY':
|
||||
raise RuntimeError('TPU "%s" is unhealthy: "%s"' %
|
||||
(compat.as_text(self._tpu), response['health']))
|
||||
|
||||
if 'networkEndpoints' in response:
|
||||
worker_list = [
|
||||
'%s:%s' % (endpoint['ipAddress'], endpoint['port'])
|
||||
for endpoint in response['networkEndpoints']
|
||||
]
|
||||
else:
|
||||
# Fall back to the deprecated response format
|
||||
instance_url = '%s:%s' % (response['ipAddress'], response['port'])
|
||||
worker_list = [instance_url]
|
||||
|
||||
cluster_spec = {self.task_type: worker_list}
|
||||
else:
|
||||
if self.rpc_layer is None:
|
||||
# Case 3.
|
||||
return None
|
||||
# Case 2.
|
||||
tpus = []
|
||||
for tpu in self._tpu.split(_ENDPOINTS_SEPARATOR):
|
||||
# We are working around the fact that GKE environment variable that is
|
||||
# supplied to us has the protocol string embedded in it, but we want
|
||||
# to strip it out for the ClusterSpec.
|
||||
if (self.rpc_layer is not None and
|
||||
tpu.startswith(self.rpc_layer + '://')):
|
||||
tpus.append(tpu[len(self.rpc_layer + '://'):])
|
||||
else:
|
||||
tpus.append(tpu)
|
||||
cluster_spec = {self.task_type: tpus}
|
||||
|
||||
if self._coordinator_address:
|
||||
# {1, 2}.a
|
||||
cluster_spec[self._coordinator_name] = [self._coordinator_address]
|
||||
|
||||
return server_lib.ClusterSpec(cluster_spec)
|
||||
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
"""Returns the number of TPU cores per worker.
|
||||
|
||||
This defaults to 8 for all current TPU configurations, and we do not need
|
||||
to query any remote systems for this.
|
||||
|
||||
Args:
|
||||
session_config: Unused. Not currently necessary to query anything as this
|
||||
number is 8 for all TPU configurations.
|
||||
"""
|
||||
del session_config # Unused. Not necessary to query anything.
|
||||
return 8
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in."""
|
||||
return self._environment
|
||||
|
||||
def _start_local_server(self):
|
||||
address = self._requestComputeMetadata('instance/network-interfaces/0/ip')
|
||||
self._server = server_lib.Server(
|
||||
{
|
||||
'local': ['0.0.0.0:0']
|
||||
}, protocol='grpc', config=None, start=True)
|
||||
# self._server.target is of the form: grpc://ipaddress:port
|
||||
target = compat.as_bytes(self._server.target)
|
||||
splits = target.split(compat.as_bytes(':'))
|
||||
assert len(splits) == 3, self._server.target
|
||||
assert splits[0] == compat.as_bytes('grpc'), self._server.target
|
||||
self._coordinator_port = compat.as_text(splits[2])
|
||||
self._coordinator_address = '%s:%s' % (
|
||||
address, compat.as_text(self._coordinator_port))
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
# TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
|
||||
return self
|
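# Illustrative usage sketch, not part of this change. Assumes a Cloud TPU named
# 'my-tpu' (placeholder) reachable through the Cloud TPU API, googleapiclient
# and oauth2client installed, and project/zone discoverable from GCE metadata
# when running on a GCE VM.
if __name__ == '__main__':
  resolver = TPUClusterResolver(tpu='my-tpu')
  print(resolver.cluster_spec().as_dict())  # e.g. {'worker': ['10.240.1.2:8470']}
  print(resolver.master())                  # e.g. 'grpc://10.240.1.2:8470'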
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import compat
|