Support fetching configured runtime version without using Cloud TPU API.
PiperOrigin-RevId: 316888815 Change-Id: I9f840076122e220c80b7b301f2290a6d4f595f1a
This commit is contained in:
parent
f44b07ed4e
commit
3d4ca5a00a
@ -20,6 +20,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
@ -48,6 +49,7 @@ _DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
|
||||
_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal'
|
||||
_DEFAULT_ENDPOINT_PORT = '8470'
|
||||
_OOM_EVENT_COOL_TIME_SEC = 90
|
||||
_VERSION_SWITCHER_ENDPOINT = 'http://{}:8475/requestversion'
|
||||
|
||||
|
||||
def _utcnow():
|
||||
@ -277,6 +279,22 @@ class Client(object):
|
||||
|
||||
def runtime_version(self):
|
||||
"""Return runtime version of the TPU."""
|
||||
|
||||
if not self._use_api:
|
||||
# Fallback on getting version directly from TPU.
|
||||
url = _VERSION_SWITCHER_ENDPOINT.format(
|
||||
self.network_endpoints()[0]['ipAddress'])
|
||||
try:
|
||||
req = request.Request(url)
|
||||
resp = request.urlopen(req)
|
||||
version_details = json.loads(resp.read())
|
||||
return version_details.get('currentVersion')
|
||||
except HTTPError as e:
|
||||
status_code = e.code
|
||||
if status_code == 404:
|
||||
return None
|
||||
else:
|
||||
raise e
|
||||
return self._get_tpu_property('tensorflowVersion')
|
||||
|
||||
def accelerator_type(self):
|
||||
@ -350,7 +368,7 @@ class Client(object):
|
||||
be sent.
|
||||
"""
|
||||
ip_address = worker['ipAddress']
|
||||
url = 'http://{}:8475/requestversion/{}?restartType={}'.format(
|
||||
url = (_VERSION_SWITCHER_ENDPOINT + '/{}?restartType={}').format(
|
||||
ip_address, version, restart_type)
|
||||
req = request.Request(url, data=b'')
|
||||
try:
|
||||
|
@ -630,6 +630,22 @@ class CloudTpuClientTest(test.TestCase):
|
||||
'http://5.6.7.8:8475/requestversion/1.15?restartType=ifNeeded'
|
||||
], sorted(paths))
|
||||
|
||||
@mock.patch.object(request, 'urlopen')
|
||||
def testGetTpuVersion(self, urlopen):
|
||||
c = client.Client(
|
||||
tpu='grpc://1.2.3.4:8470')
|
||||
resp = mock.Mock()
|
||||
resp.read.side_effect = ['{}', '{"currentVersion": "someVersion"}']
|
||||
urlopen.return_value = resp
|
||||
self.assertIsNone(c.runtime_version(), 'Missing key should be handled.')
|
||||
self.assertEqual(
|
||||
'someVersion', c.runtime_version(), 'Should return configured version.')
|
||||
paths = [call[0][0].full_url for call in urlopen.call_args_list]
|
||||
self.assertCountEqual([
|
||||
'http://1.2.3.4:8475/requestversion',
|
||||
'http://1.2.3.4:8475/requestversion',
|
||||
], sorted(paths))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -18,4 +18,4 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
__version__ = "0.9"
|
||||
__version__ = "0.10"
|
||||
|
Loading…
x
Reference in New Issue
Block a user