Expose additional APIs for cloud-tpu-client
PiperOrigin-RevId: 289733231
Change-Id: I5e12da47d3560ca65ae0ce9ad09fff18fcbb2146
This commit is contained in:
parent
886f2e05bb
commit
c00122685c
@ -156,6 +156,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
|
||||
'ipAddress': '10.1.2.3',
|
||||
'port': '8470',
|
||||
'state': 'READY',
|
||||
'health': 'HEALTHY'
|
||||
}
|
||||
}
|
||||
@ -189,6 +190,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
|
||||
'ipAddress': '10.1.2.3',
|
||||
'port': '8470',
|
||||
'state': 'READY',
|
||||
'health': 'HEALTHY'
|
||||
}
|
||||
}
|
||||
@ -235,6 +237,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
|
||||
'ipAddress': '10.1.2.3',
|
||||
'port': '8470',
|
||||
'state': 'READY',
|
||||
'health': 'HEALTHY'
|
||||
}
|
||||
}
|
||||
@ -282,6 +285,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
def testNewNetworkEndpointFormat(self):
|
||||
tpu_map = {
|
||||
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
|
||||
'state': 'READY',
|
||||
'health': 'HEALTHY',
|
||||
'networkEndpoints': [{
|
||||
'ipAddress': '10.2.3.4',
|
||||
@ -312,6 +316,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
def testPodResolution(self):
|
||||
tpu_map = {
|
||||
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
|
||||
'state': 'READY',
|
||||
'health':
|
||||
'HEALTHY',
|
||||
'networkEndpoints': [
|
||||
@ -361,6 +366,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
def testPodResolutionNoCoordinator(self):
|
||||
tpu_map = {
|
||||
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
|
||||
'state': 'READY',
|
||||
'health':
|
||||
'HEALTHY',
|
||||
'networkEndpoints': [
|
||||
@ -504,6 +510,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
def testOverrideTaskTypeAndIndexAndGetMaster(self):
|
||||
tpu_map = {
|
||||
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
|
||||
'state': 'READY',
|
||||
'health':
|
||||
'HEALTHY',
|
||||
'networkEndpoints': [
|
||||
@ -626,6 +633,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
|
||||
tpu_map = {
|
||||
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
|
||||
'state': 'READY',
|
||||
'health':
|
||||
'HEALTHY',
|
||||
'networkEndpoints': [
|
||||
|
@ -18,4 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from .version import __version__
|
||||
|
||||
from tensorflow.python.tpu.client.client import Client
|
||||
|
@ -188,6 +188,13 @@ class Client(object):
|
||||
'doublecheck the tpu argument in the TPUClusterResolver '
|
||||
'constructor. Exception: %s' % (self._tpu, e))
|
||||
|
||||
def _get_tpu_property(self, key):
|
||||
if self._use_api:
|
||||
metadata = self._fetch_cloud_tpu_metadata()
|
||||
return metadata.get(key)
|
||||
|
||||
return None
|
||||
|
||||
def __enter__(self):
|
||||
self._open = True
|
||||
|
||||
@ -206,12 +213,19 @@ class Client(object):
|
||||
|
||||
def state(self):
|
||||
"""Return state of the TPU."""
|
||||
if self._use_api:
|
||||
metadata = self._fetch_cloud_tpu_metadata()
|
||||
if 'state' in metadata:
|
||||
return metadata['state']
|
||||
return self._get_tpu_property('state')
|
||||
|
||||
return None
|
||||
def health(self):
|
||||
"""Return health of the TPU."""
|
||||
return self._get_tpu_property('health')
|
||||
|
||||
def runtime_version(self):
|
||||
"""Return runtime version of the TPU."""
|
||||
return self._get_tpu_property('tensorflowVersion')
|
||||
|
||||
def accelerator_type(self):
|
||||
"""Return accelerator type of the TPU."""
|
||||
return self._get_tpu_property('acceleratorType')
|
||||
|
||||
def api_available(self):
|
||||
"""Return if the Cloud TPU API is available, if not certain features will not work."""
|
||||
@ -229,11 +243,11 @@ class Client(object):
|
||||
"""Return a list of tpu endpoints."""
|
||||
if not self._use_api:
|
||||
return list(_environment_var_to_network_endpoints(self._tpu))
|
||||
response = self._fetch_cloud_tpu_metadata() # pylint: disable=protected-access
|
||||
response = self._fetch_cloud_tpu_metadata()
|
||||
|
||||
if 'state' in response and response['state'] != 'READY':
|
||||
if response.get('state') != 'READY':
|
||||
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
|
||||
(self._tpu, response['state']))
|
||||
(self._tpu, response.get('state')))
|
||||
if 'networkEndpoints' in response:
|
||||
return response['networkEndpoints']
|
||||
else:
|
||||
|
@ -145,6 +145,21 @@ class CloudTpuClientTest(test.TestCase):
|
||||
'port': '8470'
|
||||
}], c.network_endpoints())
|
||||
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testNetworkEndpointsNotReadyWithApi(self):
|
||||
tpu_map = {
|
||||
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
|
||||
'ipAddress': '10.1.2.3',
|
||||
'port': '8470',
|
||||
}
|
||||
}
|
||||
c = client.Client(
|
||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError, 'TPU .* is not yet ready; state: "None"',
|
||||
c.network_endpoints)
|
||||
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testInitializeNoArgumentsWithEnvironmentVariable(self):
|
||||
@ -153,7 +168,8 @@ class CloudTpuClientTest(test.TestCase):
|
||||
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
|
||||
'ipAddress': '10.1.2.3',
|
||||
'port': '8470',
|
||||
'health': 'HEALTHY'
|
||||
'state': 'READY',
|
||||
'health': 'HEALTHY',
|
||||
}
|
||||
}
|
||||
c = client.Client(
|
||||
@ -167,7 +183,8 @@ class CloudTpuClientTest(test.TestCase):
|
||||
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
|
||||
'ipAddress': '10.1.2.3',
|
||||
'port': '8470',
|
||||
'health': 'HEALTHY'
|
||||
'state': 'READY',
|
||||
'health': 'HEALTHY',
|
||||
}
|
||||
}
|
||||
c = client.Client(
|
||||
@ -246,6 +263,57 @@ class CloudTpuClientTest(test.TestCase):
|
||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||
self.assertEqual(False, c.recoverable())
|
||||
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testHealthApi(self):
|
||||
tpu_map = {
|
||||
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
|
||||
'ipAddress': '10.1.2.3',
|
||||
'port': '8470',
|
||||
'state': 'PREEMPTED',
|
||||
'health': 'HEALTHY',
|
||||
'acceleratorType': 'v3-8',
|
||||
'tensorflowVersion': 'nightly',
|
||||
}
|
||||
}
|
||||
c = client.Client(
|
||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||
self.assertEqual('HEALTHY', c.health())
|
||||
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testRuntimeVersionApi(self):
|
||||
tpu_map = {
|
||||
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
|
||||
'ipAddress': '10.1.2.3',
|
||||
'port': '8470',
|
||||
'state': 'PREEMPTED',
|
||||
'health': 'HEALTHY',
|
||||
'acceleratorType': 'v3-8',
|
||||
'tensorflowVersion': 'nightly',
|
||||
}
|
||||
}
|
||||
c = client.Client(
|
||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||
self.assertEqual('nightly', c.runtime_version())
|
||||
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testAcceleratorTypeApi(self):
|
||||
tpu_map = {
|
||||
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
|
||||
'ipAddress': '10.1.2.3',
|
||||
'port': '8470',
|
||||
'state': 'PREEMPTED',
|
||||
'health': 'HEALTHY',
|
||||
'acceleratorType': 'v3-8',
|
||||
'tensorflowVersion': 'nightly',
|
||||
}
|
||||
}
|
||||
c = client.Client(
|
||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||
self.assertEqual('v3-8', c.accelerator_type())
|
||||
|
||||
def testHandlesByteStrings(self):
|
||||
self.assertEqual(
|
||||
client.Client(
|
||||
|
@ -18,4 +18,4 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
__version__ = "0.2"
|
||||
__version__ = "0.5"
|
||||
|
Loading…
Reference in New Issue
Block a user