Support fetching configured runtime version without using Cloud TPU API.

PiperOrigin-RevId: 316888815
Change-Id: I9f840076122e220c80b7b301f2290a6d4f595f1a
This commit is contained in:
Michael Banfield 2020-06-17 08:02:15 -07:00 committed by TensorFlower Gardener
parent f44b07ed4e
commit 3d4ca5a00a
3 changed files with 36 additions and 2 deletions

View File

@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import datetime
import json
import logging
import os
import time
@ -48,6 +49,7 @@ _DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal'
_DEFAULT_ENDPOINT_PORT = '8470'
_OOM_EVENT_COOL_TIME_SEC = 90
_VERSION_SWITCHER_ENDPOINT = 'http://{}:8475/requestversion'
def _utcnow():
@ -277,6 +279,22 @@ class Client(object):
def runtime_version(self):
"""Return runtime version of the TPU."""
if not self._use_api:
# Fallback on getting version directly from TPU.
url = _VERSION_SWITCHER_ENDPOINT.format(
self.network_endpoints()[0]['ipAddress'])
try:
req = request.Request(url)
resp = request.urlopen(req)
version_details = json.loads(resp.read())
return version_details.get('currentVersion')
except HTTPError as e:
status_code = e.code
if status_code == 404:
return None
else:
raise e
return self._get_tpu_property('tensorflowVersion')
def accelerator_type(self):
@ -350,7 +368,7 @@ class Client(object):
be sent.
"""
ip_address = worker['ipAddress']
url = 'http://{}:8475/requestversion/{}?restartType={}'.format(
url = (_VERSION_SWITCHER_ENDPOINT + '/{}?restartType={}').format(
ip_address, version, restart_type)
req = request.Request(url, data=b'')
try:

View File

@ -630,6 +630,22 @@ class CloudTpuClientTest(test.TestCase):
'http://5.6.7.8:8475/requestversion/1.15?restartType=ifNeeded'
], sorted(paths))
@mock.patch.object(request, 'urlopen')
def testGetTpuVersion(self, urlopen):
c = client.Client(
tpu='grpc://1.2.3.4:8470')
resp = mock.Mock()
resp.read.side_effect = ['{}', '{"currentVersion": "someVersion"}']
urlopen.return_value = resp
self.assertIsNone(c.runtime_version(), 'Missing key should be handled.')
self.assertEqual(
'someVersion', c.runtime_version(), 'Should return configured version.')
paths = [call[0][0].full_url for call in urlopen.call_args_list]
self.assertCountEqual([
'http://1.2.3.4:8475/requestversion',
'http://1.2.3.4:8475/requestversion',
], sorted(paths))
if __name__ == '__main__':
test.main()

View File

@ -18,4 +18,4 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__version__ = "0.9"
__version__ = "0.10"