Support fetching configured runtime version without using Cloud TPU API.
PiperOrigin-RevId: 316888815 Change-Id: I9f840076122e220c80b7b301f2290a6d4f595f1a
This commit is contained in:
parent
f44b07ed4e
commit
3d4ca5a00a
@ -20,6 +20,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
@ -48,6 +49,7 @@ _DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
|
|||||||
_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal'
|
_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal'
|
||||||
_DEFAULT_ENDPOINT_PORT = '8470'
|
_DEFAULT_ENDPOINT_PORT = '8470'
|
||||||
_OOM_EVENT_COOL_TIME_SEC = 90
|
_OOM_EVENT_COOL_TIME_SEC = 90
|
||||||
|
_VERSION_SWITCHER_ENDPOINT = 'http://{}:8475/requestversion'
|
||||||
|
|
||||||
|
|
||||||
def _utcnow():
|
def _utcnow():
|
||||||
@ -277,6 +279,22 @@ class Client(object):
|
|||||||
|
|
||||||
def runtime_version(self):
|
def runtime_version(self):
|
||||||
"""Return runtime version of the TPU."""
|
"""Return runtime version of the TPU."""
|
||||||
|
|
||||||
|
if not self._use_api:
|
||||||
|
# Fallback on getting version directly from TPU.
|
||||||
|
url = _VERSION_SWITCHER_ENDPOINT.format(
|
||||||
|
self.network_endpoints()[0]['ipAddress'])
|
||||||
|
try:
|
||||||
|
req = request.Request(url)
|
||||||
|
resp = request.urlopen(req)
|
||||||
|
version_details = json.loads(resp.read())
|
||||||
|
return version_details.get('currentVersion')
|
||||||
|
except HTTPError as e:
|
||||||
|
status_code = e.code
|
||||||
|
if status_code == 404:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
return self._get_tpu_property('tensorflowVersion')
|
return self._get_tpu_property('tensorflowVersion')
|
||||||
|
|
||||||
def accelerator_type(self):
|
def accelerator_type(self):
|
||||||
@ -350,7 +368,7 @@ class Client(object):
|
|||||||
be sent.
|
be sent.
|
||||||
"""
|
"""
|
||||||
ip_address = worker['ipAddress']
|
ip_address = worker['ipAddress']
|
||||||
url = 'http://{}:8475/requestversion/{}?restartType={}'.format(
|
url = (_VERSION_SWITCHER_ENDPOINT + '/{}?restartType={}').format(
|
||||||
ip_address, version, restart_type)
|
ip_address, version, restart_type)
|
||||||
req = request.Request(url, data=b'')
|
req = request.Request(url, data=b'')
|
||||||
try:
|
try:
|
||||||
|
@ -630,6 +630,22 @@ class CloudTpuClientTest(test.TestCase):
|
|||||||
'http://5.6.7.8:8475/requestversion/1.15?restartType=ifNeeded'
|
'http://5.6.7.8:8475/requestversion/1.15?restartType=ifNeeded'
|
||||||
], sorted(paths))
|
], sorted(paths))
|
||||||
|
|
||||||
|
@mock.patch.object(request, 'urlopen')
|
||||||
|
def testGetTpuVersion(self, urlopen):
|
||||||
|
c = client.Client(
|
||||||
|
tpu='grpc://1.2.3.4:8470')
|
||||||
|
resp = mock.Mock()
|
||||||
|
resp.read.side_effect = ['{}', '{"currentVersion": "someVersion"}']
|
||||||
|
urlopen.return_value = resp
|
||||||
|
self.assertIsNone(c.runtime_version(), 'Missing key should be handled.')
|
||||||
|
self.assertEqual(
|
||||||
|
'someVersion', c.runtime_version(), 'Should return configured version.')
|
||||||
|
paths = [call[0][0].full_url for call in urlopen.call_args_list]
|
||||||
|
self.assertCountEqual([
|
||||||
|
'http://1.2.3.4:8475/requestversion',
|
||||||
|
'http://1.2.3.4:8475/requestversion',
|
||||||
|
], sorted(paths))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -18,4 +18,4 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
__version__ = "0.9"
|
__version__ = "0.10"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user