Support configuring TPU software version from cloud tpu client.
PiperOrigin-RevId: 294715133 Change-Id: Iae9610a978fe5c9ba1e69407942927f9c1f0a4fc
This commit is contained in:
parent
c283962960
commit
4166342300
@ -40,6 +40,9 @@ tf_py_test(
|
||||
grpc_enabled = True,
|
||||
main = "client_test.py",
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_oss_py2",
|
||||
],
|
||||
deps = [
|
||||
":client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -22,8 +22,10 @@ from __future__ import print_function
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from concurrent import futures
|
||||
|
||||
from six.moves.urllib import request
|
||||
from six.moves.urllib.error import HTTPError
|
||||
|
||||
_GOOGLE_API_CLIENT_INSTALLED = True
|
||||
try:
|
||||
@ -277,3 +279,37 @@ class Client(object):
|
||||
time.sleep(interval)
|
||||
|
||||
logging.warning('TPU "%s" is healthy.', self.name())
|
||||
|
||||
def configure_tpu_version(self, version):
|
||||
"""Configure TPU software version."""
|
||||
|
||||
def configure_worker(worker):
|
||||
"""Configure individual TPU worker.
|
||||
|
||||
Args:
|
||||
worker: A dict with the field ipAddress where the configure request will
|
||||
be sent.
|
||||
"""
|
||||
ip_address = worker['ipAddress']
|
||||
url = 'http://{}:8475/requestversion/{}'.format(ip_address, version)
|
||||
req = request.Request(url, data=b'')
|
||||
try:
|
||||
request.urlopen(req)
|
||||
except HTTPError as e:
|
||||
status_code = e.code
|
||||
if status_code == 404:
|
||||
raise Exception(
|
||||
'Tensorflow version {} is not available on Cloud TPU, '
|
||||
'try a previous nightly version or refer to '
|
||||
'https://cloud.google.com/tpu/docs/release-notes for '
|
||||
'the latest official version.'.format(version))
|
||||
else:
|
||||
raise Exception('Failed to configure worker {}'.format(ip_address))
|
||||
|
||||
workers = self.network_endpoints()
|
||||
|
||||
with futures.ThreadPoolExecutor(max_workers=len(workers)) as executor:
|
||||
results = executor.map(configure_worker, workers)
|
||||
for result in results:
|
||||
if result:
|
||||
result.result()
|
||||
|
@ -22,6 +22,8 @@ from __future__ import print_function
|
||||
import os
|
||||
import time
|
||||
|
||||
from six.moves.urllib import request
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu.client import client
|
||||
|
||||
@ -394,5 +396,36 @@ class CloudTpuClientTest(test.TestCase):
|
||||
'Timed out waiting for TPU .* to become healthy'):
|
||||
c.wait_for_healthy(timeout_s=80, interval=5)
|
||||
|
||||
@mock.patch.object(request, 'urlopen')
|
||||
def testConfigureTpuVersion(self, urlopen):
|
||||
tpu_map = {
|
||||
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
|
||||
'state':
|
||||
'READY',
|
||||
'networkEndpoints': [
|
||||
{
|
||||
'ipAddress': '1.2.3.4'
|
||||
},
|
||||
{
|
||||
'ipAddress': '5.6.7.8'
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
c = client.Client(
|
||||
tpu='tpu_name',
|
||||
project='test-project',
|
||||
zone='us-central1-c',
|
||||
service=self.mock_service_client(tpu_map=tpu_map))
|
||||
c.configure_tpu_version('1.15')
|
||||
|
||||
paths = [call[0][0].full_url for call in urlopen.call_args_list]
|
||||
|
||||
self.assertEqual([
|
||||
'http://1.2.3.4:8475/requestversion/1.15',
|
||||
'http://5.6.7.8:8475/requestversion/1.15'
|
||||
], sorted(paths))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -53,7 +53,6 @@ function main() {
|
||||
pushd ${TMPDIR}
|
||||
echo $(date) : "=== Building wheel"
|
||||
echo $(pwd)
|
||||
python setup.py bdist_wheel >/dev/null
|
||||
python3 setup.py bdist_wheel >/dev/null
|
||||
mkdir -p ${DEST}
|
||||
cp dist/* ${DEST}
|
||||
|
@ -37,8 +37,6 @@ setup(
|
||||
'Intended Audience :: Education',
|
||||
'Intended Audience :: Science/Research',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Programming Language :: Python :: 2',
|
||||
'Programming Language :: Python :: 2.7',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.4',
|
||||
'Programming Language :: Python :: 3.5',
|
||||
|
@ -18,4 +18,4 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
__version__ = "0.5"
|
||||
__version__ = "0.6"
|
||||
|
Loading…
Reference in New Issue
Block a user