From c00122685c2e23213920d62a26e4f239e88a463a Mon Sep 17 00:00:00 2001 From: Jin Young Sohn Date: Tue, 14 Jan 2020 14:37:29 -0800 Subject: [PATCH] Expose additional APIs for cloud-tpu-client PiperOrigin-RevId: 289733231 Change-Id: I5e12da47d3560ca65ae0ce9ad09fff18fcbb2146 --- .../tpu_cluster_resolver_test.py | 8 +++ tensorflow/python/tpu/client/__init__.py | 2 + tensorflow/python/tpu/client/client.py | 30 +++++--- tensorflow/python/tpu/client/client_test.py | 72 ++++++++++++++++++- tensorflow/python/tpu/client/version.py | 2 +- 5 files changed, 103 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py index 6f862c6e1f0..1fad0a3fc95 100644 --- a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py +++ b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py @@ -156,6 +156,7 @@ class TPUClusterResolverTest(test.TestCase): 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', 'port': '8470', + 'state': 'READY', 'health': 'HEALTHY' } } @@ -189,6 +190,7 @@ class TPUClusterResolverTest(test.TestCase): 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', 'port': '8470', + 'state': 'READY', 'health': 'HEALTHY' } } @@ -235,6 +237,7 @@ class TPUClusterResolverTest(test.TestCase): 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', 'port': '8470', + 'state': 'READY', 'health': 'HEALTHY' } } @@ -282,6 +285,7 @@ class TPUClusterResolverTest(test.TestCase): def testNewNetworkEndpointFormat(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'state': 'READY', 'health': 'HEALTHY', 'networkEndpoints': [{ 'ipAddress': '10.2.3.4', @@ -312,6 +316,7 @@ class TPUClusterResolverTest(test.TestCase): def testPodResolution(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'state': 'READY', 'health': 'HEALTHY', 'networkEndpoints': [ @@ -361,6 +366,7 @@ class TPUClusterResolverTest(test.TestCase): def testPodResolutionNoCoordinator(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'state': 'READY', 'health': 'HEALTHY', 'networkEndpoints': [ @@ -504,6 +510,7 @@ class TPUClusterResolverTest(test.TestCase): def testOverrideTaskTypeAndIndexAndGetMaster(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'state': 'READY', 'health': 'HEALTHY', 'networkEndpoints': [ @@ -626,6 +633,7 @@ class TPUClusterResolverTest(test.TestCase): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'state': 'READY', 'health': 'HEALTHY', 'networkEndpoints': [ diff --git a/tensorflow/python/tpu/client/__init__.py b/tensorflow/python/tpu/client/__init__.py index 04d4faf9c68..976f374af63 100644 --- a/tensorflow/python/tpu/client/__init__.py +++ b/tensorflow/python/tpu/client/__init__.py @@ -18,4 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from .version import __version__ + from tensorflow.python.tpu.client.client import Client diff --git a/tensorflow/python/tpu/client/client.py b/tensorflow/python/tpu/client/client.py index fc630ba5191..3c4e65e780a 100644 --- a/tensorflow/python/tpu/client/client.py +++ b/tensorflow/python/tpu/client/client.py @@ -188,6 +188,13 @@ class Client(object): 'doublecheck the tpu argument in the TPUClusterResolver ' 'constructor. Exception: %s' % (self._tpu, e)) + def _get_tpu_property(self, key): + if self._use_api: + metadata = self._fetch_cloud_tpu_metadata() + return metadata.get(key) + + return None + def __enter__(self): self._open = True @@ -206,12 +213,19 @@ class Client(object): def state(self): """Return state of the TPU.""" - if self._use_api: - metadata = self._fetch_cloud_tpu_metadata() - if 'state' in metadata: - return metadata['state'] + return self._get_tpu_property('state') - return None + def health(self): + """Return health of the TPU.""" + return self._get_tpu_property('health') + + def runtime_version(self): + """Return runtime version of the TPU.""" + return self._get_tpu_property('tensorflowVersion') + + def accelerator_type(self): + """Return accelerator type of the TPU.""" + return self._get_tpu_property('acceleratorType') def api_available(self): """Return if the Cloud TPU API is available, if not certain features will not work.""" @@ -229,11 +243,11 @@ class Client(object): """Return a list of tpu endpoints.""" if not self._use_api: return list(_environment_var_to_network_endpoints(self._tpu)) - response = self._fetch_cloud_tpu_metadata() # pylint: disable=protected-access + response = self._fetch_cloud_tpu_metadata() - if 'state' in response and response['state'] != 'READY': + if response.get('state') != 'READY': raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % - (self._tpu, response['state'])) + (self._tpu, response.get('state'))) if 'networkEndpoints' in response: return response['networkEndpoints'] else: diff --git a/tensorflow/python/tpu/client/client_test.py b/tensorflow/python/tpu/client/client_test.py index 133e79a2cf7..4a9c0c6ede0 100644 --- a/tensorflow/python/tpu/client/client_test.py +++ b/tensorflow/python/tpu/client/client_test.py @@ -145,6 +145,21 @@ class CloudTpuClientTest(test.TestCase): 'port': '8470' }], c.network_endpoints()) + @mock.patch.object(client, '_request_compute_metadata', + mock_request_compute_metadata) + def testNetworkEndpointsNotReadyWithApi(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/tpu_name': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + } + } + c = client.Client( + tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map)) + self.assertRaisesRegex( + RuntimeError, 'TPU .* is not yet ready; state: "None"', + c.network_endpoints) + @mock.patch.object(client, '_request_compute_metadata', mock_request_compute_metadata) def testInitializeNoArgumentsWithEnvironmentVariable(self): @@ -153,7 +168,8 @@ class CloudTpuClientTest(test.TestCase): 'projects/test-project/locations/us-central1-c/nodes/tpu_name': { 'ipAddress': '10.1.2.3', 'port': '8470', - 'health': 'HEALTHY' + 'state': 'READY', + 'health': 'HEALTHY', } } c = client.Client( @@ -167,7 +183,8 @@ class CloudTpuClientTest(test.TestCase): 'projects/test-project/locations/us-central1-c/nodes/tpu_name': { 'ipAddress': '10.1.2.3', 'port': '8470', - 'health': 'HEALTHY' + 'state': 'READY', + 'health': 'HEALTHY', } } c = client.Client( @@ -246,6 +263,57 @@ class CloudTpuClientTest(test.TestCase): tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map)) self.assertEqual(False, c.recoverable()) + @mock.patch.object(client, '_request_compute_metadata', + mock_request_compute_metadata) + def testHealthApi(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/tpu_name': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'state': 'PREEMPTED', + 'health': 'HEALTHY', + 'acceleratorType': 'v3-8', + 'tensorflowVersion': 'nightly', + } + } + c = client.Client( + tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map)) + self.assertEqual('HEALTHY', c.health()) + + @mock.patch.object(client, '_request_compute_metadata', + mock_request_compute_metadata) + def testRuntimeVersionApi(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/tpu_name': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'state': 'PREEMPTED', + 'health': 'HEALTHY', + 'acceleratorType': 'v3-8', + 'tensorflowVersion': 'nightly', + } + } + c = client.Client( + tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map)) + self.assertEqual('nightly', c.runtime_version()) + + @mock.patch.object(client, '_request_compute_metadata', + mock_request_compute_metadata) + def testAcceleratorTypeApi(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/tpu_name': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'state': 'PREEMPTED', + 'health': 'HEALTHY', + 'acceleratorType': 'v3-8', + 'tensorflowVersion': 'nightly', + } + } + c = client.Client( + tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map)) + self.assertEqual('v3-8', c.accelerator_type()) + def testHandlesByteStrings(self): self.assertEqual( client.Client( diff --git a/tensorflow/python/tpu/client/version.py b/tensorflow/python/tpu/client/version.py index f9cc53c8906..d468474fd09 100644 --- a/tensorflow/python/tpu/client/version.py +++ b/tensorflow/python/tpu/client/version.py @@ -18,4 +18,4 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -__version__ = "0.2" +__version__ = "0.5"