Create pip scripts for cloud tpu client.

PiperOrigin-RevId: 285260886
Change-Id: I15f4dbc3f1bbab44700b855e19251c7df1c46c31
This commit is contained in:
Michael Banfield 2019-12-12 13:40:25 -08:00 committed by TensorFlower Gardener
parent b18db52708
commit 5364121e85
13 changed files with 345 additions and 130 deletions

View File

@ -60,58 +60,14 @@ py_library(
],
)
py_library(
name = "cloud_tpu_client",
srcs = ["cloud_tpu_client.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
"@six_archive//:six",
],
)
tf_py_test(
name = "cloud_tpu_client_py_test",
size = "small",
srcs = ["cloud_tpu_client_test.py"],
additional_deps = [
":cloud_tpu_client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "cloud_tpu_client_test.py",
python_version = "PY3",
)
tf_py_test(
name = "cloud_tpu_client_py2_test",
size = "small",
srcs = ["cloud_tpu_client_test.py"],
additional_deps = [
":cloud_tpu_client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "cloud_tpu_client_test.py",
python_version = "PY3",
)
py_library(
name = "tpu_cluster_resolver_py",
srcs = ["tpu_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":cloud_tpu_client",
":base_cluster_resolver_py",
"//tensorflow/python:training_server_lib",
"//tensorflow/python/tpu/client",
] + tf_additional_rpc_deps(),
)
@ -191,6 +147,7 @@ tf_py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
"//tensorflow/python/tpu/client:client",
],
grpc_enabled = True,
main = "tpu_cluster_resolver_test.py",

View File

@ -21,16 +21,20 @@ from __future__ import print_function
import collections
import re
from tensorflow.python.distribute.cluster_resolver.cloud_tpu_client import CloudTPUClient
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import get_accelerator_devices
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
try:
from cloud_tpu_client import client # pylint: disable=g-import-not-at-top
except ImportError:
logging.warning(
'Falling back to tensorflow client, its recommended to install the cloud '
'tpu client directly with pip install cloud-tpu-client .')
from tensorflow.python.tpu.client import client
def is_running_in_gce():
return True
@ -44,7 +48,7 @@ DeviceDetails = collections.namedtuple(
@tf_export('distribute.cluster_resolver.TPUClusterResolver')
class TPUClusterResolver(ClusterResolver):
class TPUClusterResolver(cluster_resolver.ClusterResolver):
"""Cluster Resolver for Google Cloud TPUs.
This is an implementation of cluster resolvers for the Google Cloud TPU
@ -135,7 +139,6 @@ class TPUClusterResolver(ClusterResolver):
filled in produce an absolute URL to the discovery document for that
service. The environment variable 'TPU_API_DISCOVERY_URL' will override
this.
**kwargs: Extra keyword arguments passed to CloudTPUClient.
Raises:
ImportError: If the googleapiclient is not installed.
@ -144,7 +147,7 @@ class TPUClusterResolver(ClusterResolver):
Google Cloud environment.
"""
self._cloud_tpu_client = CloudTPUClient(
self._cloud_tpu_client = client.Client(
tpu=tpu,
zone=zone,
project=project,
@ -208,7 +211,7 @@ class TPUClusterResolver(ClusterResolver):
if not job_tasks:
raise ValueError('No TPUs with the specified names exist.')
master = job_tasks[0]
return format_master_url(master, 'grpc')
return cluster_resolver.format_master_url(master, 'grpc')
def get_master(self):
return self.master()
@ -277,7 +280,8 @@ class TPUClusterResolver(ClusterResolver):
while True:
try:
device_details = TPUClusterResolver._get_device_dict_and_cores(
get_accelerator_devices(self.master(), config_proto=config_proto))
cluster_resolver.get_accelerator_devices(
self.master(), config_proto=config_proto))
break
except errors.DeadlineExceededError:
error_message = ('Failed to connect to master. The TPU might not be '

View File

@ -25,16 +25,24 @@ from six.moves.urllib.error import URLError
from tensorflow.python import framework
from tensorflow.python.client import session
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
from tensorflow.python.eager.context import LogicalDevice
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
mock = test.mock
try:
from cloud_tpu_client import client # pylint: disable=g-import-not-at-top
except ImportError:
logging.warning(
'Falling back to tensorflow client, its recommended to install the cloud '
'tpu client directly with pip install cloud-tpu-client .')
from tensorflow.python.tpu.client import client
class MockRequestClass(object):
@ -141,7 +149,7 @@ class TPUClusterResolverTest(test.TestCase):
def testIsRunningInGce(self):
self.assertTrue(resolver.is_running_in_gce())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testRetrieveProjectAndZoneFromMetadata(self):
tpu_map = {
@ -174,7 +182,7 @@ class TPUClusterResolverTest(test.TestCase):
self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self):
tpu_map = {
@ -200,7 +208,7 @@ class TPUClusterResolverTest(test.TestCase):
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testNotReadyCloudTpu(self):
tpu_map = {
@ -299,7 +307,7 @@ class TPUClusterResolverTest(test.TestCase):
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testPodResolution(self):
tpu_map = {

View File

@ -0,0 +1,68 @@
# Cloud TPU Client.
load("//tensorflow:tensorflow.bzl", "tf_py_test")
package(
default_visibility = [
"//tensorflow:internal",
],
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "client",
srcs = [
"client.py",
"version.py",
],
srcs_version = "PY2AND3",
deps = [
"@six_archive//:six",
],
)
py_library(
name = "client_lib",
srcs = [
"__init__.py",
],
srcs_version = "PY2AND3",
deps = [
":client",
"@six_archive//:six",
],
)
tf_py_test(
name = "client_py_test",
size = "small",
srcs = ["client_test.py"],
additional_deps = [
":client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "client_test.py",
python_version = "PY3",
)
tf_py_test(
name = "client_py2_test",
size = "small",
srcs = ["client_test.py"],
additional_deps = [
":client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "client_test.py",
python_version = "PY2",
)

View File

@ -0,0 +1,21 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Cloud TPU Client."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.tpu.client.client import Client

View File

@ -23,8 +23,6 @@ import os
from six.moves.urllib import request
from tensorflow.python.util import compat
_GOOGLE_API_CLIENT_INSTALLED = True
try:
from apiclient import discovery # pylint: disable=g-import-not-at-top
@ -49,23 +47,23 @@ def _request_compute_metadata(path):
'%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path),
headers={'Metadata-Flavor': 'Google'})
resp = request.urlopen(req)
return compat.as_bytes(resp.read())
return resp.read()
def _environment_var_to_network_endpoints(endpoints):
"""Yields a dict with ip address and port."""
for endpoint in endpoints.split(compat.as_text(',')):
grpc_prefix = compat.as_text('grpc://')
for endpoint in endpoints.split(','):
grpc_prefix = 'grpc://'
if endpoint.startswith(grpc_prefix):
endpoint = endpoint.split(grpc_prefix)[1]
parts = endpoint.split(compat.as_text(':'))
parts = endpoint.split(':')
ip_address = parts[0]
port = _DEFAULT_ENDPOINT_PORT
if len(parts) > 1:
port = parts[1]
yield {
'ipAddress': compat.as_text(ip_address),
'port': compat.as_text(port)
'ipAddress': ip_address,
'port': port
}
@ -79,7 +77,7 @@ def _get_tpu_name(tpu):
return None
class CloudTPUClient(object):
class Client(object):
"""Client for working with the Cloud TPU API.
This client is intended to be used for resolving tpu name to ip addresses.
@ -108,7 +106,7 @@ class CloudTPUClient(object):
if tpu is None:
raise ValueError('Please provide a TPU Name to connect to.')
self._tpu = compat.as_text(tpu)
self._tpu = tpu
self._use_api = not tpu.startswith('grpc://')
self._service = service
@ -124,12 +122,11 @@ class CloudTPUClient(object):
if project:
self._project = project
else:
self._project = compat.as_str(
_request_compute_metadata('project/project-id'))
self._project = _request_compute_metadata('project/project-id')
if zone:
self._zone = zone
else:
zone_path = compat.as_str(_request_compute_metadata('instance/zone'))
zone_path = _request_compute_metadata('instance/zone')
self._zone = zone_path.split('/')[-1]
self._discovery_url = _environment_discovery_url() or discovery_url
@ -166,7 +163,7 @@ class CloudTPUClient(object):
"""Returns the TPU metadata object from the TPU Get API call."""
try:
full_name = 'projects/%s/locations/%s/nodes/%s' % (
self._project, self._zone, compat.as_text(self._tpu))
self._project, self._zone, self._tpu)
service = self._tpu_service()
r = service.projects().locations().nodes().get(name=full_name)
return r.execute()
@ -220,7 +217,7 @@ class CloudTPUClient(object):
if 'state' in response and response['state'] != 'READY':
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
(compat.as_text(self._tpu), response['state']))
(self._tpu, response['state']))
if 'networkEndpoints' in response:
return response['networkEndpoints']
else:

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for cloud_tpu_client."""
"""Tests for cloud tpu client."""
from __future__ import absolute_import
from __future__ import division
@ -21,8 +21,8 @@ from __future__ import print_function
import os
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
from tensorflow.python.platform import test
from tensorflow.python.tpu.client import client
mock = test.mock
@ -85,19 +85,19 @@ class CloudTpuClientTest(test.TestCase):
def testEnvironmentDiscoveryUrl(self):
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
self.assertEqual('https://{api}.internal/{apiVersion}',
(cloud_tpu_client._environment_discovery_url()))
(client._environment_discovery_url()))
def testEnvironmentVarToNetworkEndpointsSingleIp(self):
self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '1234'}],
list(cloud_tpu_client._environment_var_to_network_endpoints(
list(client._environment_var_to_network_endpoints(
'1.2.3.4:1234')))
def testEnvironmentVarToNetworkEndpointsSingleGrpcAddress(self):
self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '2000'}],
list(
cloud_tpu_client._environment_var_to_network_endpoints(
client._environment_var_to_network_endpoints(
'grpc://1.2.3.4:2000')))
def testEnvironmentVarToNetworkEndpointsMultipleIps(self):
@ -105,47 +105,47 @@ class CloudTpuClientTest(test.TestCase):
[{'ipAddress': '1.2.3.4', 'port': '2000'},
{'ipAddress': '5.6.7.8', 'port': '1234'}],
list(
cloud_tpu_client._environment_var_to_network_endpoints(
client._environment_var_to_network_endpoints(
'1.2.3.4:2000,5.6.7.8:1234')))
def testEnvironmentVarToNetworkEndpointsMultipleGrpcAddresses(self):
self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '2000'},
{'ipAddress': '5.6.7.8', 'port': '1234'}],
list(cloud_tpu_client._environment_var_to_network_endpoints(
list(client._environment_var_to_network_endpoints(
'grpc://1.2.3.4:2000,grpc://5.6.7.8:1234')))
def testEnvironmentVarToNetworkEndpointsMissingPortAndMixed(self):
self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '2000'},
{'ipAddress': '5.6.7.8', 'port': '8470'}],
list(cloud_tpu_client._environment_var_to_network_endpoints(
list(client._environment_var_to_network_endpoints(
'1.2.3.4:2000,grpc://5.6.7.8')))
def testInitializeNoArguments(self):
with self.assertRaisesRegex(
ValueError, 'Please provide a TPU Name to connect to.'):
cloud_tpu_client.CloudTPUClient()
client.Client()
def testInitializeMultiElementTpuArray(self):
with self.assertRaisesRegex(
NotImplementedError,
'Using multiple TPUs in a single session is not yet implemented'):
cloud_tpu_client.CloudTPUClient(tpu=['multiple', 'elements'])
client.Client(tpu=['multiple', 'elements'])
def assertClientContains(self, client):
self.assertEqual('tpu_name', client._tpu)
self.assertEqual(True, client._use_api)
self.assertEqual(None, client._credentials)
self.assertEqual('test-project', client._project)
self.assertEqual('us-central1-c', client._zone)
self.assertEqual(None, client._discovery_url)
def assertClientContains(self, c):
self.assertEqual('tpu_name', c._tpu)
self.assertEqual(True, c._use_api)
self.assertEqual(None, c._credentials)
self.assertEqual('test-project', c._project)
self.assertEqual('us-central1-c', c._zone)
self.assertEqual(None, c._discovery_url)
self.assertEqual([{
'ipAddress': '10.1.2.3',
'port': '8470'
}], client.network_endpoints())
}], c.network_endpoints())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testInitializeNoArgumentsWithEnvironmentVariable(self):
os.environ['TPU_NAME'] = 'tpu_name'
@ -156,11 +156,11 @@ class CloudTpuClientTest(test.TestCase):
'health': 'HEALTHY'
}
}
client = cloud_tpu_client.CloudTPUClient(
c = client.Client(
service=self.mock_service_client(tpu_map=tpu_map))
self.assertClientContains(client)
self.assertClientContains(c)
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testInitializeTpuName(self):
tpu_map = {
@ -170,42 +170,42 @@ class CloudTpuClientTest(test.TestCase):
'health': 'HEALTHY'
}
}
client = cloud_tpu_client.CloudTPUClient(
c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertClientContains(client)
self.assertClientContains(c)
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testInitializeIpAddress(self):
client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
self.assertEqual('grpc://1.2.3.4:8470', client._tpu)
self.assertEqual(False, client._use_api)
self.assertEqual(None, client._service)
self.assertEqual(None, client._credentials)
self.assertEqual(None, client._project)
self.assertEqual(None, client._zone)
self.assertEqual(None, client._discovery_url)
c = client.Client(tpu='grpc://1.2.3.4:8470')
self.assertEqual('grpc://1.2.3.4:8470', c._tpu)
self.assertEqual(False, c._use_api)
self.assertEqual(None, c._service)
self.assertEqual(None, c._credentials)
self.assertEqual(None, c._project)
self.assertEqual(None, c._zone)
self.assertEqual(None, c._discovery_url)
self.assertEqual([{
'ipAddress': '1.2.3.4',
'port': '8470'
}], client.network_endpoints())
}], c.network_endpoints())
def testInitializeWithoutMetadata(self):
client = cloud_tpu_client.CloudTPUClient(
c = client.Client(
tpu='tpu_name', project='project', zone='zone')
self.assertEqual('tpu_name', client._tpu)
self.assertEqual(True, client._use_api)
self.assertEqual(None, client._service)
self.assertEqual(None, client._credentials)
self.assertEqual('project', client._project)
self.assertEqual('zone', client._zone)
self.assertEqual(None, client._discovery_url)
self.assertEqual('tpu_name', c._tpu)
self.assertEqual(True, c._use_api)
self.assertEqual(None, c._service)
self.assertEqual(None, c._credentials)
self.assertEqual('project', c._project)
self.assertEqual('zone', c._zone)
self.assertEqual(None, c._discovery_url)
def testRecoverableNoApiAccess(self):
client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
self.assertEqual(True, client.recoverable())
c = client.Client(tpu='grpc://1.2.3.4:8470')
self.assertEqual(True, c.recoverable())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testRecoverableNoState(self):
tpu_map = {
@ -214,11 +214,11 @@ class CloudTpuClientTest(test.TestCase):
'port': '8470',
}
}
client = cloud_tpu_client.CloudTPUClient(
c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(True, client.recoverable())
self.assertEqual(True, c.recoverable())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testRecoverableReady(self):
tpu_map = {
@ -228,11 +228,11 @@ class CloudTpuClientTest(test.TestCase):
'state': 'READY',
}
}
client = cloud_tpu_client.CloudTPUClient(
c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(True, client.recoverable())
self.assertEqual(True, c.recoverable())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
@mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata)
def testRecoverablePreempted(self):
tpu_map = {
@ -242,9 +242,9 @@ class CloudTpuClientTest(test.TestCase):
'state': 'PREEMPTED',
}
}
client = cloud_tpu_client.CloudTPUClient(
c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(False, client.recoverable())
self.assertEqual(False, c.recoverable())
if __name__ == '__main__':

View File

@ -0,0 +1,15 @@
# Description:
# Tools for building the Cloud TPU Client pip package.
package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0
sh_binary(
name = "build_pip_package",
srcs = ["build_pip_package.sh"],
data = [
"setup.py",
"//tensorflow/python/tpu/client:client_lib",
],
)

View File

@ -0,0 +1,3 @@
Client responsible for communicating the Cloud TPU API. Released seperately from tensorflow.
https://pypi.org/project/cloud-tpu-client/

View File

@ -0,0 +1,65 @@
#!/usr/bin/env bash
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
set -e
if [ "$(uname)" = "Darwin" ]; then
sedi="sed -i ''"
else
sedi="sed -i"
fi
PACKAGE_NAME="cloud_tpu_client"
PIP_PACKAGE="tensorflow/python/tpu/client/pip_package"
RUNFILES="bazel-bin/tensorflow/python/tpu/client/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow/python/tpu/client"
function main() {
if [ $# -lt 1 ] ; then
echo "No destination dir provided"
exit 1
fi
DEST=$1
TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX)
echo $(date) : "=== Using tmpdir: ${TMPDIR}"
cp ${PIP_PACKAGE}/README ${TMPDIR}
cp ${PIP_PACKAGE}/setup.py ${TMPDIR}
mkdir ${TMPDIR}/${PACKAGE_NAME}
cp -a ${RUNFILES}/. ${TMPDIR}/${PACKAGE_NAME}/
# Fix the import statements to reflect the copied over path.
find ${TMPDIR}/${PACKAGE_NAME} -name \*.py |
xargs $sedi -e '
s/^from tensorflow.python.tpu.client/from '${PACKAGE_NAME}'/
'
echo $(ls $TMPDIR)
pushd ${TMPDIR}
echo $(date) : "=== Building wheel"
echo $(pwd)
python setup.py bdist_wheel >/dev/null
python3 setup.py bdist_wheel >/dev/null
mkdir -p ${DEST}
cp dist/* ${DEST}
popd
rm -rf ${TMPDIR}
echo $(date) : "=== Output wheel file is in: ${DEST}"
}
main "$@"

View File

@ -0,0 +1,56 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Cloud TPU Client package."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from cloud_tpu_client.version import __version__
from setuptools import find_packages
from setuptools import setup
setup(
name='cloud-tpu-client',
version=__version__.replace('-', ''),
description='Client for using Cloud TPUs',
long_description='Client for using Cloud TPUs',
url='https://cloud.google.com/tpu/',
author='Google Inc.',
author_email='packages@tensorflow.org',
packages=find_packages(),
classifiers=[
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
],
license='Apache 2.0',
keywords='tensorflow tpu',
install_requires=['google-api-python-client', 'oauth2client']
)

View File

@ -0,0 +1,21 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Cloud TPU Client version information."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__version__ = "0.2"

View File

@ -21,10 +21,10 @@
function install_ctpu {
PIP_CMD="${1:-pip}"
# TPUClusterResolver has a runtime dependency on these Python libraries when
# resolving a Cloud TPU. It's very likely we want these installed if we're
# TPUClusterResolver has a runtime dependency cloud-tpu-client when
# resolving a Cloud TPU. It's very likely we want this installed if we're
# using CTPU.
"${PIP_CMD}" install --user --upgrade google-api-python-client oauth2client
"${PIP_CMD}" install --user --upgrade cloud-tpu-client
wget -nv "https://dl.google.com/cloud_tpu/ctpu/latest/linux/ctpu"
chmod a+x ctpu