Add abstract methods/properties to the base Cluster Resolver class.

Also fixes a bug in KubernetesClusterResolver.master() where we were not reading the task_type and task_index attributes correctly, and adds a test for it.

PiperOrigin-RevId: 221660325
Frank Chen 2018-11-15 11:36:11 -08:00 committed by TensorFlower Gardener
parent e0ef2053ac
commit dcf390ede4
4 changed files with 48 additions and 12 deletions

@@ -44,6 +44,17 @@ class ClusterResolver(object):
   automatically discover and resolve IP addresses for various TensorFlow
   workers. This will eventually allow us to automatically recover from
   underlying machine failures and scale TensorFlow worker clusters up and down.
+
+  Note to Implementors: In addition to these abstract methods, you must also
+  implement the task_type, task_index, and rpc_layer attributes. You may
+  choose to implement them either as properties with getters and setters, or
+  by setting the attributes directly.
+
+  - task_type is the name of the server's current named job (e.g. 'worker',
+    'ps' in a distributed parameterized training job).
+  - task_index is the ordinal index of the server within the task type.
+  - rpc_layer is the protocol used by TensorFlow to communicate with other
+    TensorFlow servers in a distributed environment.
   """

   @abc.abstractmethod
@@ -60,8 +71,7 @@ class ClusterResolver(object):
     management system every time this function is invoked and reconstructing
     a cluster_spec, rather than attempting to cache anything.
     """
-    raise NotImplementedError(
-        'cluster_spec is not implemented for {}.'.format(self))
+    raise NotImplementedError()

   @abc.abstractmethod
   def master(self, task_type=None, task_index=None, rpc_layer=None):
@@ -79,7 +89,27 @@ class ClusterResolver(object):
     returned is up-to-date at the time of calling this function. This usually
     means retrieving the master every time this function is invoked.
     """
-    raise NotImplementedError('master is not implemented for {}.'.format(self))
+    raise NotImplementedError()
+
+  @abc.abstractmethod
+  def num_accelerators_per_worker(self, session_config=None):
+    """Returns the number of accelerator cores per worker.
+
+    This returns the number of accelerator cores (such as GPUs and TPUs)
+    available per worker. If workers only have CPU cores available, then this
+    should return 0. This method will query the master for this information
+    if it is not otherwise known.
+
+    Args:
+      session_config: (Optional) Configuration for starting a new session to
+        query how many accelerator cores it has.
+    """
+    raise NotImplementedError()
+
+  @abc.abstractproperty
+  def environment(self):
+    """Returns the current environment which TensorFlow is running in."""
+    raise NotImplementedError()


 class SimpleClusterResolver(ClusterResolver):
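For context, a concrete resolver now has to provide all four abstract members plus the three documented attributes. Below is a minimal sketch of a conforming subclass; the StaticClusterResolver name is made up for illustration, and the import path is an assumption (the class lived under contrib when this change landed):

# Illustrative only -- not part of this change. A minimal concrete resolver
# implementing the full abstract interface plus the three required attributes.
import tensorflow as tf
from tensorflow.contrib.cluster_resolver import ClusterResolver  # assumed path


class StaticClusterResolver(ClusterResolver):
  """Resolves a fixed, in-memory cluster definition."""

  def __init__(self, cluster_def, task_type=None, task_index=None,
               rpc_layer='grpc'):
    self._cluster_spec = tf.train.ClusterSpec(cluster_def)
    # The three required attributes, implemented here as plain attributes,
    # which is one of the two options the new docstring allows.
    self.task_type = task_type
    self.task_index = task_index
    self.rpc_layer = rpc_layer

  def cluster_spec(self):
    return self._cluster_spec

  def master(self, task_type=None, task_index=None, rpc_layer=None):
    # Explicit arguments take precedence over the instance attributes.
    task_type = task_type if task_type is not None else self.task_type
    task_index = task_index if task_index is not None else self.task_index
    if task_type is None or task_index is None:
      return ''
    address = self._cluster_spec.task_address(task_type, task_index)
    layer = rpc_layer or self.rpc_layer
    return '%s://%s' % (layer, address) if layer else address

  def num_accelerators_per_worker(self, session_config=None):
    del session_config  # A static cluster has nothing to query.
    return 0

  @property
  def environment(self):
    return ''  # Empty string: not running in a Google-internal environment.

For instance, StaticClusterResolver({'worker': ['10.1.2.3:8470']}, task_type='worker', task_index=0).master() would return 'grpc://10.1.2.3:8470'.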

@@ -113,9 +113,9 @@ class KubernetesClusterResolver(ClusterResolver):
           self.cluster_spec().task_address(task_type, task_index),
           rpc_layer or self.rpc_layer)

-    if self._task_type is not None and self._task_index is not None:
+    if self.task_type is not None and self.task_index is not None:
       return format_master_url(
-          self.cluster_spec().task_address(self._task_type, self._task_index),
+          self.cluster_spec().task_address(self.task_type, self.task_index),
          rpc_layer or self.rpc_layer)

    return ''
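The bug: master() read the private _task_type/_task_index names, while the documented interface exposes the public attributes, so values assigned to resolver.task_type were never consulted and the method fell through to returning ''. A self-contained sketch of the pattern, with illustrative names:

# Toy repro of the attribute mismatch fixed above (names illustrative).
class BuggyResolver(object):
  def __init__(self, task_type=None, task_index=None):
    self._task_type = task_type
    self._task_index = task_index

  def master(self):
    # Reads the private names, so assignments to the public attributes
    # (as the interface documents them) are silently ignored.
    if self._task_type is not None and self._task_index is not None:
      return 'grpc://resolved-from-attributes'
    return ''


r = BuggyResolver()
r.task_type = 'worker'   # creates a *new* public attribute only
r.task_index = 0
print(repr(r.master()))  # prints '': the overrides were never seen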

@@ -118,10 +118,11 @@ class KubernetesClusterResolverTest(test.TestCase):
     cluster_resolver = KubernetesClusterResolver(
         override_client=_mock_kubernetes_client(
             {'job-name=tensorflow': ret}))

-    cluster_resolver.task_type = 'blah'
-    cluster_resolver.task_index = 1
-    self.assertEqual(cluster_resolver.task_type, 'blah')
-    self.assertEqual(cluster_resolver.task_index, 1)
+    cluster_resolver.task_type = 'worker'
+    cluster_resolver.task_index = 0
+    self.assertEqual(cluster_resolver.task_type, 'worker')
+    self.assertEqual(cluster_resolver.task_index, 0)
+    self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
     self.assertEqual(cluster_resolver.master('worker', 2),
                      'grpc://10.1.2.5:8470')
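The new assertion on master() with no arguments exercises the attribute path that was previously broken. Reassembled from the hunks above, the fixed method resolves in this order: explicit arguments first, then the public attributes, then '' (a sketch; the leading branch is inferred from the unchanged context lines, and format_master_url is the helper already used in the diff):

  def master(self, task_type=None, task_index=None, rpc_layer=None):
    # 1. Explicit arguments win.
    if task_type is not None and task_index is not None:
      return format_master_url(
          self.cluster_spec().task_address(task_type, task_index),
          rpc_layer or self.rpc_layer)

    # 2. Fall back to the public attributes (the fix above).
    if self.task_type is not None and self.task_index is not None:
      return format_master_url(
          self.cluster_spec().task_address(self.task_type, self.task_index),
          rpc_layer or self.rpc_layer)

    # 3. Nothing to resolve.
    return ''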

@@ -192,13 +192,13 @@ class TPUClusterResolver(ClusterResolver):

     if tpu.startswith('grpc://'):
       # Cloud environment, where we are using GRPC to communicate to TPUs.
-      self.environment = ''
+      self._environment = ''
     elif tpu == 'local' or not tpu:
       # Google environment, where the TPU is attached to the host.
-      self.environment = 'google'
+      self._environment = 'google'
     elif tpu.startswith('/bns'):
       # Google environment, where we reach the TPU through BNS.
-      self.environment = 'google'
+      self._environment = 'google'

     # If TPU is in the Google environment or exists locally, we don't use any
     # RPC layer.
@@ -398,6 +398,11 @@ class TPUClusterResolver(ClusterResolver):
     del session_config  # Unused. Not necessary to query anything.
     return 8

+  @property
+  def environment(self):
+    """Returns the current environment which TensorFlow is running in."""
+    return self._environment
+
   def _start_local_server(self):
     address = self._requestComputeMetadata('instance/network-interfaces/0/ip')
     self._server = server_lib.Server(
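With the read-only property in place, callers can branch on where the TPU lives without touching private state; per the constructor logic above, 'google' means the TPU is attached locally or reached through BNS, and '' means a Cloud-style setup reached over gRPC. A usage sketch (the TPU address is made up, and it is assumed that a grpc:// address avoids any metadata lookup in the constructor):

# Hedged usage sketch for the new property.
resolver = TPUClusterResolver(tpu='grpc://10.240.1.2:8470')  # illustrative
if resolver.environment == 'google':
  print('Google-internal environment: TPU attached locally or via BNS.')
else:
  print('Cloud-style environment: talking to the TPU over gRPC.')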