Support configuring the TPU software version from the Cloud TPU client.

PiperOrigin-RevId: 294715133
Change-Id: Iae9610a978fe5c9ba1e69407942927f9c1f0a4fc
This commit is contained in:
Michael Banfield 2020-02-12 11:34:56 -08:00 committed by TensorFlower Gardener
parent c283962960
commit 4166342300
6 changed files with 73 additions and 4 deletions

View File

@ -40,6 +40,9 @@ tf_py_test(
grpc_enabled = True,
main = "client_test.py",
python_version = "PY3",
tags = [
"no_oss_py2",
],
deps = [
":client",
"//tensorflow/python:client_testlib",

View File

@ -22,8 +22,10 @@ from __future__ import print_function
import logging
import os
import time
from concurrent import futures
from six.moves.urllib import request
from six.moves.urllib.error import HTTPError
_GOOGLE_API_CLIENT_INSTALLED = True
try:
@ -277,3 +279,37 @@ class Client(object):
time.sleep(interval)
logging.warning('TPU "%s" is healthy.', self.name())
def configure_tpu_version(self, version):
  """Configure the TPU software version on every worker of this TPU.

  For each entry returned by `self.network_endpoints()`, a request with an
  empty body is sent to `http://<ipAddress>:8475/requestversion/<version>`.
  The requests are issued concurrently from a thread pool, one thread per
  worker.

  Args:
    version: The software version to request. It is formatted into the
      request URL as-is.

  Raises:
    Exception: If a worker answers with an HTTP error. A 404 is reported as
      the requested version not being available on Cloud TPU; any other
      status is reported as a failure to configure that worker.
  """

  def configure_worker(worker):
    """Configure individual TPU worker.

    Args:
      worker: A dict with the field ipAddress where the configure request will
        be sent.
    """
    ip_address = worker['ipAddress']
    url = 'http://{}:8475/requestversion/{}'.format(ip_address, version)
    # Passing a non-None `data` makes urllib send a POST instead of a GET.
    req = request.Request(url, data=b'')
    try:
      response = request.urlopen(req)
    except HTTPError as e:
      if e.code == 404:
        raise Exception(
            'Tensorflow version {} is not available on Cloud TPU, '
            'try a previous nightly version or refer to '
            'https://cloud.google.com/tpu/docs/release-notes for '
            'the latest official version.'.format(version))
      raise Exception('Failed to configure worker {}'.format(ip_address))
    # The response body is not needed; close it so the connection is released
    # promptly instead of waiting for garbage collection.
    response.close()

  workers = self.network_endpoints()
  if not workers:
    # ThreadPoolExecutor rejects max_workers=0 with an opaque ValueError, so
    # handle the "nothing to configure" case explicitly.
    logging.warning('No TPU workers found; skipping version configuration.')
    return

  with futures.ThreadPoolExecutor(max_workers=len(workers)) as executor:
    # executor.map yields each worker's return value (always None here), not
    # Future objects. Draining the iterator is what re-raises any exception
    # thrown inside configure_worker.
    for _ in executor.map(configure_worker, workers):
      pass

View File

@ -22,6 +22,8 @@ from __future__ import print_function
import os
import time
from six.moves.urllib import request
from tensorflow.python.platform import test
from tensorflow.python.tpu.client import client
@ -394,5 +396,36 @@ class CloudTpuClientTest(test.TestCase):
'Timed out waiting for TPU .* to become healthy'):
c.wait_for_healthy(timeout_s=80, interval=5)
@mock.patch.object(request, 'urlopen')
def testConfigureTpuVersion(self, urlopen):
  """A version request is sent to each network endpoint of the TPU."""
  node_name = 'projects/test-project/locations/us-central1-c/nodes/tpu_name'
  endpoints = [{'ipAddress': '1.2.3.4'}, {'ipAddress': '5.6.7.8'}]
  tpu_map = {node_name: {'state': 'READY', 'networkEndpoints': endpoints}}
  service = self.mock_service_client(tpu_map=tpu_map)

  c = client.Client(
      tpu='tpu_name',
      project='test-project',
      zone='us-central1-c',
      service=service)
  c.configure_tpu_version('1.15')

  # The first positional argument of every urlopen call is the Request
  # object. Workers run on a thread pool, so compare in sorted order.
  requested_urls = []
  for recorded_call in urlopen.call_args_list:
    requested_urls.append(recorded_call[0][0].full_url)
  expected_urls = [
      'http://1.2.3.4:8475/requestversion/1.15',
      'http://5.6.7.8:8475/requestversion/1.15',
  ]
  self.assertEqual(expected_urls, sorted(requested_urls))
if __name__ == '__main__':
test.main()

View File

@ -53,7 +53,6 @@ function main() {
pushd ${TMPDIR}
echo $(date) : "=== Building wheel"
echo $(pwd)
python setup.py bdist_wheel >/dev/null
python3 setup.py bdist_wheel >/dev/null
mkdir -p ${DEST}
cp dist/* ${DEST}

View File

@ -37,8 +37,6 @@ setup(
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',

View File

@ -18,4 +18,4 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__version__ = "0.5"
__version__ = "0.6"