Expose additional APIs for cloud-tpu-client

PiperOrigin-RevId: 289733231
Change-Id: I5e12da47d3560ca65ae0ce9ad09fff18fcbb2146
This commit is contained in:
Jin Young Sohn 2020-01-14 14:37:29 -08:00 committed by TensorFlower Gardener
parent 886f2e05bb
commit c00122685c
5 changed files with 103 additions and 11 deletions

View File

@ -156,6 +156,7 @@ class TPUClusterResolverTest(test.TestCase):
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
'port': '8470',
'state': 'READY',
'health': 'HEALTHY'
}
}
@ -189,6 +190,7 @@ class TPUClusterResolverTest(test.TestCase):
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
'port': '8470',
'state': 'READY',
'health': 'HEALTHY'
}
}
@ -235,6 +237,7 @@ class TPUClusterResolverTest(test.TestCase):
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
'port': '8470',
'state': 'READY',
'health': 'HEALTHY'
}
}
@ -282,6 +285,7 @@ class TPUClusterResolverTest(test.TestCase):
def testNewNetworkEndpointFormat(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'state': 'READY',
'health': 'HEALTHY',
'networkEndpoints': [{
'ipAddress': '10.2.3.4',
@ -312,6 +316,7 @@ class TPUClusterResolverTest(test.TestCase):
def testPodResolution(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'state': 'READY',
'health':
'HEALTHY',
'networkEndpoints': [
@ -361,6 +366,7 @@ class TPUClusterResolverTest(test.TestCase):
def testPodResolutionNoCoordinator(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'state': 'READY',
'health':
'HEALTHY',
'networkEndpoints': [
@ -504,6 +510,7 @@ class TPUClusterResolverTest(test.TestCase):
def testOverrideTaskTypeAndIndexAndGetMaster(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'state': 'READY',
'health':
'HEALTHY',
'networkEndpoints': [
@ -626,6 +633,7 @@ class TPUClusterResolverTest(test.TestCase):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'state': 'READY',
'health':
'HEALTHY',
'networkEndpoints': [

View File

@ -18,4 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .version import __version__
from tensorflow.python.tpu.client.client import Client

View File

@ -188,6 +188,13 @@ class Client(object):
'doublecheck the tpu argument in the TPUClusterResolver '
'constructor. Exception: %s' % (self._tpu, e))
def _get_tpu_property(self, key):
if self._use_api:
metadata = self._fetch_cloud_tpu_metadata()
return metadata.get(key)
return None
def __enter__(self):
self._open = True
@ -206,12 +213,19 @@ class Client(object):
def state(self):
"""Return state of the TPU."""
if self._use_api:
metadata = self._fetch_cloud_tpu_metadata()
if 'state' in metadata:
return metadata['state']
return self._get_tpu_property('state')
return None
def health(self):
"""Return health of the TPU."""
return self._get_tpu_property('health')
def runtime_version(self):
"""Return runtime version of the TPU."""
return self._get_tpu_property('tensorflowVersion')
def accelerator_type(self):
"""Return accelerator type of the TPU."""
return self._get_tpu_property('acceleratorType')
def api_available(self):
"""Return if the Cloud TPU API is available, if not certain features will not work."""
@ -229,11 +243,11 @@ class Client(object):
"""Return a list of tpu endpoints."""
if not self._use_api:
return list(_environment_var_to_network_endpoints(self._tpu))
response = self._fetch_cloud_tpu_metadata() # pylint: disable=protected-access
response = self._fetch_cloud_tpu_metadata()
if 'state' in response and response['state'] != 'READY':
if response.get('state') != 'READY':
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
(self._tpu, response['state']))
(self._tpu, response.get('state')))
if 'networkEndpoints' in response:
return response['networkEndpoints']
else:

View File

@ -145,6 +145,21 @@ class CloudTpuClientTest(test.TestCase):
'port': '8470'
}], c.network_endpoints())
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testNetworkEndpointsNotReadyWithApi(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
'ipAddress': '10.1.2.3',
'port': '8470',
}
}
c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertRaisesRegex(
RuntimeError, 'TPU .* is not yet ready; state: "None"',
c.network_endpoints)
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testInitializeNoArgumentsWithEnvironmentVariable(self):
@ -153,7 +168,8 @@ class CloudTpuClientTest(test.TestCase):
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
'ipAddress': '10.1.2.3',
'port': '8470',
'health': 'HEALTHY'
'state': 'READY',
'health': 'HEALTHY',
}
}
c = client.Client(
@ -167,7 +183,8 @@ class CloudTpuClientTest(test.TestCase):
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
'ipAddress': '10.1.2.3',
'port': '8470',
'health': 'HEALTHY'
'state': 'READY',
'health': 'HEALTHY',
}
}
c = client.Client(
@ -246,6 +263,57 @@ class CloudTpuClientTest(test.TestCase):
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(False, c.recoverable())
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testHealthApi(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
'ipAddress': '10.1.2.3',
'port': '8470',
'state': 'PREEMPTED',
'health': 'HEALTHY',
'acceleratorType': 'v3-8',
'tensorflowVersion': 'nightly',
}
}
c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual('HEALTHY', c.health())
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testRuntimeVersionApi(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
'ipAddress': '10.1.2.3',
'port': '8470',
'state': 'PREEMPTED',
'health': 'HEALTHY',
'acceleratorType': 'v3-8',
'tensorflowVersion': 'nightly',
}
}
c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual('nightly', c.runtime_version())
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testAcceleratorTypeApi(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
'ipAddress': '10.1.2.3',
'port': '8470',
'state': 'PREEMPTED',
'health': 'HEALTHY',
'acceleratorType': 'v3-8',
'tensorflowVersion': 'nightly',
}
}
c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual('v3-8', c.accelerator_type())
def testHandlesByteStrings(self):
self.assertEqual(
client.Client(

View File

@ -18,4 +18,4 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__version__ = "0.2"
__version__ = "0.5"