Add abstract methods/properties to the base Cluster Resolver class.
Also fixes a bug in KubernetesClusterResolver.master() where we were not getting the attribute correctly and added a test for it. PiperOrigin-RevId: 221660325
This commit is contained in:
parent
e0ef2053ac
commit
dcf390ede4
@ -44,6 +44,17 @@ class ClusterResolver(object):
|
||||
automatically discover and resolve IP addresses for various TensorFlow
|
||||
workers. This will eventually allow us to automatically recover from
|
||||
underlying machine failures and scale TensorFlow worker clusters up and down.
|
||||
|
||||
Note to Implementors: In addition to these abstract methods, you must also
|
||||
implement the task_type, task_index, and rpc_layer attributes. You may choose
|
||||
to implement them either as properties with getters or setters or directly
|
||||
set the attributes.
|
||||
|
||||
- task_type is the name of the server's current named job (e.g. 'worker',
|
||||
'ps' in a distributed parameterized training job).
|
||||
- task_index is the ordinal index of the server within the task type.
|
||||
- rpc_layer is the protocol used by TensorFlow to communicate with other
|
||||
TensorFlow servers in a distributed environment.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
@ -60,8 +71,7 @@ class ClusterResolver(object):
|
||||
management system every time this function is invoked and reconstructing
|
||||
a cluster_spec, rather than attempting to cache anything.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
'cluster_spec is not implemented for {}.'.format(self))
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def master(self, task_type=None, task_index=None, rpc_layer=None):
|
||||
@ -79,7 +89,27 @@ class ClusterResolver(object):
|
||||
returned is up-to-date at the time to calling this function. This usually
|
||||
means retrieving the master every time this function is invoked.
|
||||
"""
|
||||
raise NotImplementedError('master is not implemented for {}.'.format(self))
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def num_accelerators_per_worker(self, session_config=None):
|
||||
"""Returns the number of accelerator cores per worker.
|
||||
|
||||
This returns the number of accelerator cores (such as GPUs and TPUs)
|
||||
available per worker. If workers only has CPU cores available, then this
|
||||
should return 0. This method will query the master for this information
|
||||
if it is not otherwise known.
|
||||
|
||||
Args:
|
||||
session_config: (Optional) Configuration for starting a new session to
|
||||
query how many accelerator cores it has.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractproperty
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SimpleClusterResolver(ClusterResolver):
|
||||
|
@ -113,9 +113,9 @@ class KubernetesClusterResolver(ClusterResolver):
|
||||
self.cluster_spec().task_address(task_type, task_index),
|
||||
rpc_layer or self.rpc_layer)
|
||||
|
||||
if self._task_type is not None and self._task_index is not None:
|
||||
if self.task_type is not None and self.task_index is not None:
|
||||
return format_master_url(
|
||||
self.cluster_spec().task_address(self._task_type, self._task_index),
|
||||
self.cluster_spec().task_address(self.task_type, self.task_index),
|
||||
rpc_layer or self.rpc_layer)
|
||||
|
||||
return ''
|
||||
|
@ -118,10 +118,11 @@ class KubernetesClusterResolverTest(test.TestCase):
|
||||
cluster_resolver = KubernetesClusterResolver(
|
||||
override_client=_mock_kubernetes_client(
|
||||
{'job-name=tensorflow': ret}))
|
||||
cluster_resolver.task_type = 'blah'
|
||||
cluster_resolver.task_index = 1
|
||||
self.assertEqual(cluster_resolver.task_type, 'blah')
|
||||
self.assertEqual(cluster_resolver.task_index, 1)
|
||||
cluster_resolver.task_type = 'worker'
|
||||
cluster_resolver.task_index = 0
|
||||
self.assertEqual(cluster_resolver.task_type, 'worker')
|
||||
self.assertEqual(cluster_resolver.task_index, 0)
|
||||
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
|
||||
self.assertEqual(cluster_resolver.master('worker', 2),
|
||||
'grpc://10.1.2.5:8470')
|
||||
|
||||
|
@ -192,13 +192,13 @@ class TPUClusterResolver(ClusterResolver):
|
||||
|
||||
if tpu.startswith('grpc://'):
|
||||
# Cloud environment, where we are using GRPC to communicate to TPUs.
|
||||
self.environment = ''
|
||||
self._environment = ''
|
||||
elif tpu == 'local' or not tpu:
|
||||
# Google environment, where the TPU is attached to the host.
|
||||
self.environment = 'google'
|
||||
self._environment = 'google'
|
||||
elif tpu.startswith('/bns'):
|
||||
# Google environment, where we reach the TPU through BNS.
|
||||
self.environment = 'google'
|
||||
self._environment = 'google'
|
||||
|
||||
# If TPU is in the Google environment or exists locally, we don't use any
|
||||
# RPC layer.
|
||||
@ -398,6 +398,11 @@ class TPUClusterResolver(ClusterResolver):
|
||||
del session_config # Unused. Not necessary to query anything.
|
||||
return 8
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
"""Returns the current environment which TensorFlow is running in."""
|
||||
return self._environment
|
||||
|
||||
def _start_local_server(self):
|
||||
address = self._requestComputeMetadata('instance/network-interfaces/0/ip')
|
||||
self._server = server_lib.Server(
|
||||
|
Loading…
Reference in New Issue
Block a user