Create pip scripts for cloud tpu client.

PiperOrigin-RevId: 285260886
Change-Id: I15f4dbc3f1bbab44700b855e19251c7df1c46c31
This commit is contained in:
Michael Banfield 2019-12-12 13:40:25 -08:00 committed by TensorFlower Gardener
parent b18db52708
commit 5364121e85
13 changed files with 345 additions and 130 deletions

View File

@ -60,58 +60,14 @@ py_library(
], ],
) )
py_library(
name = "cloud_tpu_client",
srcs = ["cloud_tpu_client.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
"@six_archive//:six",
],
)
tf_py_test(
name = "cloud_tpu_client_py_test",
size = "small",
srcs = ["cloud_tpu_client_test.py"],
additional_deps = [
":cloud_tpu_client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "cloud_tpu_client_test.py",
python_version = "PY3",
)
tf_py_test(
name = "cloud_tpu_client_py2_test",
size = "small",
srcs = ["cloud_tpu_client_test.py"],
additional_deps = [
":cloud_tpu_client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "cloud_tpu_client_test.py",
python_version = "PY3",
)
py_library( py_library(
name = "tpu_cluster_resolver_py", name = "tpu_cluster_resolver_py",
srcs = ["tpu_cluster_resolver.py"], srcs = ["tpu_cluster_resolver.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":cloud_tpu_client",
":base_cluster_resolver_py", ":base_cluster_resolver_py",
"//tensorflow/python:training_server_lib", "//tensorflow/python:training_server_lib",
"//tensorflow/python/tpu/client",
] + tf_additional_rpc_deps(), ] + tf_additional_rpc_deps(),
) )
@ -191,6 +147,7 @@ tf_py_test(
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib", "//tensorflow/python:training_server_lib",
"//tensorflow/python/tpu/client:client",
], ],
grpc_enabled = True, grpc_enabled = True,
main = "tpu_cluster_resolver_test.py", main = "tpu_cluster_resolver_test.py",

View File

@ -21,16 +21,20 @@ from __future__ import print_function
import collections import collections
import re import re
from tensorflow.python.distribute.cluster_resolver.cloud_tpu_client import CloudTPUClient from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import get_accelerator_devices
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# Prefer the standalone cloud-tpu-client pip package. When it is not
# installed, log a warning and fall back to the copy of the client that ships
# inside TensorFlow, so `client` is always bound to one of the two modules.
try:
from cloud_tpu_client import client # pylint: disable=g-import-not-at-top
except ImportError:
logging.warning(
'Falling back to tensorflow client, its recommended to install the cloud '
'tpu client directly with pip install cloud-tpu-client .')
from tensorflow.python.tpu.client import client
def is_running_in_gce(): def is_running_in_gce():
return True return True
@ -44,7 +48,7 @@ DeviceDetails = collections.namedtuple(
@tf_export('distribute.cluster_resolver.TPUClusterResolver') @tf_export('distribute.cluster_resolver.TPUClusterResolver')
class TPUClusterResolver(ClusterResolver): class TPUClusterResolver(cluster_resolver.ClusterResolver):
"""Cluster Resolver for Google Cloud TPUs. """Cluster Resolver for Google Cloud TPUs.
This is an implementation of cluster resolvers for the Google Cloud TPU This is an implementation of cluster resolvers for the Google Cloud TPU
@ -135,7 +139,6 @@ class TPUClusterResolver(ClusterResolver):
filled in produce an absolute URL to the discovery document for that filled in produce an absolute URL to the discovery document for that
service. The environment variable 'TPU_API_DISCOVERY_URL' will override service. The environment variable 'TPU_API_DISCOVERY_URL' will override
this. this.
**kwargs: Extra keyword arguments passed to CloudTPUClient.
Raises: Raises:
ImportError: If the googleapiclient is not installed. ImportError: If the googleapiclient is not installed.
@ -144,7 +147,7 @@ class TPUClusterResolver(ClusterResolver):
Google Cloud environment. Google Cloud environment.
""" """
self._cloud_tpu_client = CloudTPUClient( self._cloud_tpu_client = client.Client(
tpu=tpu, tpu=tpu,
zone=zone, zone=zone,
project=project, project=project,
@ -208,7 +211,7 @@ class TPUClusterResolver(ClusterResolver):
if not job_tasks: if not job_tasks:
raise ValueError('No TPUs with the specified names exist.') raise ValueError('No TPUs with the specified names exist.')
master = job_tasks[0] master = job_tasks[0]
return format_master_url(master, 'grpc') return cluster_resolver.format_master_url(master, 'grpc')
def get_master(self): def get_master(self):
return self.master() return self.master()
@ -277,7 +280,8 @@ class TPUClusterResolver(ClusterResolver):
while True: while True:
try: try:
device_details = TPUClusterResolver._get_device_dict_and_cores( device_details = TPUClusterResolver._get_device_dict_and_cores(
get_accelerator_devices(self.master(), config_proto=config_proto)) cluster_resolver.get_accelerator_devices(
self.master(), config_proto=config_proto))
break break
except errors.DeadlineExceededError: except errors.DeadlineExceededError:
error_message = ('Failed to connect to master. The TPU might not be ' error_message = ('Failed to connect to master. The TPU might not be '

View File

@ -25,16 +25,24 @@ from six.moves.urllib.error import URLError
from tensorflow.python import framework from tensorflow.python import framework
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
from tensorflow.python.eager.context import LogicalDevice from tensorflow.python.eager.context import LogicalDevice
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
from tensorflow.python.util import compat from tensorflow.python.util import compat
mock = test.mock mock = test.mock
# Mirror the import order used by tpu_cluster_resolver: take the pip-installed
# cloud_tpu_client when available, otherwise the client bundled with
# TensorFlow. The tests below patch `_request_compute_metadata` on whichever
# module ends up bound to `client`.
try:
from cloud_tpu_client import client # pylint: disable=g-import-not-at-top
except ImportError:
logging.warning(
'Falling back to tensorflow client, its recommended to install the cloud '
'tpu client directly with pip install cloud-tpu-client .')
from tensorflow.python.tpu.client import client
class MockRequestClass(object): class MockRequestClass(object):
@ -141,7 +149,7 @@ class TPUClusterResolverTest(test.TestCase):
def testIsRunningInGce(self): def testIsRunningInGce(self):
self.assertTrue(resolver.is_running_in_gce()) self.assertTrue(resolver.is_running_in_gce())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata', @mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata) mock_request_compute_metadata)
def testRetrieveProjectAndZoneFromMetadata(self): def testRetrieveProjectAndZoneFromMetadata(self):
tpu_map = { tpu_map = {
@ -174,7 +182,7 @@ class TPUClusterResolverTest(test.TestCase):
self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470') self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata', @mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata) mock_request_compute_metadata)
def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self): def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self):
tpu_map = { tpu_map = {
@ -200,7 +208,7 @@ class TPUClusterResolverTest(test.TestCase):
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470') self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata', @mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata) mock_request_compute_metadata)
def testNotReadyCloudTpu(self): def testNotReadyCloudTpu(self):
tpu_map = { tpu_map = {
@ -299,7 +307,7 @@ class TPUClusterResolverTest(test.TestCase):
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master()) self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata', @mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata) mock_request_compute_metadata)
def testPodResolution(self): def testPodResolution(self):
tpu_map = { tpu_map = {

View File

@ -0,0 +1,68 @@
# Cloud TPU Client.
load("//tensorflow:tensorflow.bzl", "tf_py_test")
package(
default_visibility = [
"//tensorflow:internal",
],
licenses = ["notice"], # Apache 2.0
)
# Cloud TPU API client sources. version.py is part of this target so the
# version string travels with client.py wherever the library is packaged.
py_library(
name = "client",
srcs = [
"client.py",
"version.py",
],
srcs_version = "PY2AND3",
deps = [
"@six_archive//:six",
],
)
# Package-level target: adds __init__.py on top of :client. This is the target
# the pip_package build script lists as its data dependency.
py_library(
name = "client_lib",
srcs = [
"__init__.py",
],
srcs_version = "PY2AND3",
deps = [
":client",
"@six_archive//:six",
],
)
# Runs client_test.py under Python 3.
tf_py_test(
name = "client_py_test",
size = "small",
srcs = ["client_test.py"],
additional_deps = [
":client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "client_test.py",
python_version = "PY3",
)
# Runs the same client_test.py under Python 2; only the target name and
# python_version differ from client_py_test.
tf_py_test(
name = "client_py2_test",
size = "small",
srcs = ["client_test.py"],
additional_deps = [
":client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "client_test.py",
python_version = "PY2",
)

View File

@ -0,0 +1,21 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Cloud TPU Client."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.tpu.client.client import Client

View File

@ -23,8 +23,6 @@ import os
from six.moves.urllib import request from six.moves.urllib import request
from tensorflow.python.util import compat
_GOOGLE_API_CLIENT_INSTALLED = True _GOOGLE_API_CLIENT_INSTALLED = True
try: try:
from apiclient import discovery # pylint: disable=g-import-not-at-top from apiclient import discovery # pylint: disable=g-import-not-at-top
@ -49,23 +47,23 @@ def _request_compute_metadata(path):
'%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path), '%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path),
headers={'Metadata-Flavor': 'Google'}) headers={'Metadata-Flavor': 'Google'})
resp = request.urlopen(req) resp = request.urlopen(req)
return compat.as_bytes(resp.read()) return resp.read()
def _environment_var_to_network_endpoints(endpoints): def _environment_var_to_network_endpoints(endpoints):
"""Yields a dict with ip address and port.""" """Yields a dict with ip address and port."""
for endpoint in endpoints.split(compat.as_text(',')): for endpoint in endpoints.split(','):
grpc_prefix = compat.as_text('grpc://') grpc_prefix = 'grpc://'
if endpoint.startswith(grpc_prefix): if endpoint.startswith(grpc_prefix):
endpoint = endpoint.split(grpc_prefix)[1] endpoint = endpoint.split(grpc_prefix)[1]
parts = endpoint.split(compat.as_text(':')) parts = endpoint.split(':')
ip_address = parts[0] ip_address = parts[0]
port = _DEFAULT_ENDPOINT_PORT port = _DEFAULT_ENDPOINT_PORT
if len(parts) > 1: if len(parts) > 1:
port = parts[1] port = parts[1]
yield { yield {
'ipAddress': compat.as_text(ip_address), 'ipAddress': ip_address,
'port': compat.as_text(port) 'port': port
} }
@ -79,7 +77,7 @@ def _get_tpu_name(tpu):
return None return None
class CloudTPUClient(object): class Client(object):
"""Client for working with the Cloud TPU API. """Client for working with the Cloud TPU API.
This client is intended to be used for resolving tpu name to ip addresses. This client is intended to be used for resolving tpu name to ip addresses.
@ -108,7 +106,7 @@ class CloudTPUClient(object):
if tpu is None: if tpu is None:
raise ValueError('Please provide a TPU Name to connect to.') raise ValueError('Please provide a TPU Name to connect to.')
self._tpu = compat.as_text(tpu) self._tpu = tpu
self._use_api = not tpu.startswith('grpc://') self._use_api = not tpu.startswith('grpc://')
self._service = service self._service = service
@ -124,12 +122,11 @@ class CloudTPUClient(object):
if project: if project:
self._project = project self._project = project
else: else:
self._project = compat.as_str( self._project = _request_compute_metadata('project/project-id')
_request_compute_metadata('project/project-id'))
if zone: if zone:
self._zone = zone self._zone = zone
else: else:
zone_path = compat.as_str(_request_compute_metadata('instance/zone')) zone_path = _request_compute_metadata('instance/zone')
self._zone = zone_path.split('/')[-1] self._zone = zone_path.split('/')[-1]
self._discovery_url = _environment_discovery_url() or discovery_url self._discovery_url = _environment_discovery_url() or discovery_url
@ -166,7 +163,7 @@ class CloudTPUClient(object):
"""Returns the TPU metadata object from the TPU Get API call.""" """Returns the TPU metadata object from the TPU Get API call."""
try: try:
full_name = 'projects/%s/locations/%s/nodes/%s' % ( full_name = 'projects/%s/locations/%s/nodes/%s' % (
self._project, self._zone, compat.as_text(self._tpu)) self._project, self._zone, self._tpu)
service = self._tpu_service() service = self._tpu_service()
r = service.projects().locations().nodes().get(name=full_name) r = service.projects().locations().nodes().get(name=full_name)
return r.execute() return r.execute()
@ -220,7 +217,7 @@ class CloudTPUClient(object):
if 'state' in response and response['state'] != 'READY': if 'state' in response and response['state'] != 'READY':
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
(compat.as_text(self._tpu), response['state'])) (self._tpu, response['state']))
if 'networkEndpoints' in response: if 'networkEndpoints' in response:
return response['networkEndpoints'] return response['networkEndpoints']
else: else:

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
# Lint as: python3 # Lint as: python3
"""Tests for cloud_tpu_client.""" """Tests for cloud tpu client."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -21,8 +21,8 @@ from __future__ import print_function
import os import os
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.tpu.client import client
mock = test.mock mock = test.mock
@ -85,19 +85,19 @@ class CloudTpuClientTest(test.TestCase):
def testEnvironmentDiscoveryUrl(self): def testEnvironmentDiscoveryUrl(self):
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
self.assertEqual('https://{api}.internal/{apiVersion}', self.assertEqual('https://{api}.internal/{apiVersion}',
(cloud_tpu_client._environment_discovery_url())) (client._environment_discovery_url()))
def testEnvironmentVarToNetworkEndpointsSingleIp(self): def testEnvironmentVarToNetworkEndpointsSingleIp(self):
self.assertEqual( self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '1234'}], [{'ipAddress': '1.2.3.4', 'port': '1234'}],
list(cloud_tpu_client._environment_var_to_network_endpoints( list(client._environment_var_to_network_endpoints(
'1.2.3.4:1234'))) '1.2.3.4:1234')))
def testEnvironmentVarToNetworkEndpointsSingleGrpcAddress(self): def testEnvironmentVarToNetworkEndpointsSingleGrpcAddress(self):
self.assertEqual( self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '2000'}], [{'ipAddress': '1.2.3.4', 'port': '2000'}],
list( list(
cloud_tpu_client._environment_var_to_network_endpoints( client._environment_var_to_network_endpoints(
'grpc://1.2.3.4:2000'))) 'grpc://1.2.3.4:2000')))
def testEnvironmentVarToNetworkEndpointsMultipleIps(self): def testEnvironmentVarToNetworkEndpointsMultipleIps(self):
@ -105,47 +105,47 @@ class CloudTpuClientTest(test.TestCase):
[{'ipAddress': '1.2.3.4', 'port': '2000'}, [{'ipAddress': '1.2.3.4', 'port': '2000'},
{'ipAddress': '5.6.7.8', 'port': '1234'}], {'ipAddress': '5.6.7.8', 'port': '1234'}],
list( list(
cloud_tpu_client._environment_var_to_network_endpoints( client._environment_var_to_network_endpoints(
'1.2.3.4:2000,5.6.7.8:1234'))) '1.2.3.4:2000,5.6.7.8:1234')))
def testEnvironmentVarToNetworkEndpointsMultipleGrpcAddresses(self): def testEnvironmentVarToNetworkEndpointsMultipleGrpcAddresses(self):
self.assertEqual( self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '2000'}, [{'ipAddress': '1.2.3.4', 'port': '2000'},
{'ipAddress': '5.6.7.8', 'port': '1234'}], {'ipAddress': '5.6.7.8', 'port': '1234'}],
list(cloud_tpu_client._environment_var_to_network_endpoints( list(client._environment_var_to_network_endpoints(
'grpc://1.2.3.4:2000,grpc://5.6.7.8:1234'))) 'grpc://1.2.3.4:2000,grpc://5.6.7.8:1234')))
def testEnvironmentVarToNetworkEndpointsMissingPortAndMixed(self): def testEnvironmentVarToNetworkEndpointsMissingPortAndMixed(self):
self.assertEqual( self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '2000'}, [{'ipAddress': '1.2.3.4', 'port': '2000'},
{'ipAddress': '5.6.7.8', 'port': '8470'}], {'ipAddress': '5.6.7.8', 'port': '8470'}],
list(cloud_tpu_client._environment_var_to_network_endpoints( list(client._environment_var_to_network_endpoints(
'1.2.3.4:2000,grpc://5.6.7.8'))) '1.2.3.4:2000,grpc://5.6.7.8')))
def testInitializeNoArguments(self): def testInitializeNoArguments(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, 'Please provide a TPU Name to connect to.'): ValueError, 'Please provide a TPU Name to connect to.'):
cloud_tpu_client.CloudTPUClient() client.Client()
def testInitializeMultiElementTpuArray(self): def testInitializeMultiElementTpuArray(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
NotImplementedError, NotImplementedError,
'Using multiple TPUs in a single session is not yet implemented'): 'Using multiple TPUs in a single session is not yet implemented'):
cloud_tpu_client.CloudTPUClient(tpu=['multiple', 'elements']) client.Client(tpu=['multiple', 'elements'])
def assertClientContains(self, client): def assertClientContains(self, c):
self.assertEqual('tpu_name', client._tpu) self.assertEqual('tpu_name', c._tpu)
self.assertEqual(True, client._use_api) self.assertEqual(True, c._use_api)
self.assertEqual(None, client._credentials) self.assertEqual(None, c._credentials)
self.assertEqual('test-project', client._project) self.assertEqual('test-project', c._project)
self.assertEqual('us-central1-c', client._zone) self.assertEqual('us-central1-c', c._zone)
self.assertEqual(None, client._discovery_url) self.assertEqual(None, c._discovery_url)
self.assertEqual([{ self.assertEqual([{
'ipAddress': '10.1.2.3', 'ipAddress': '10.1.2.3',
'port': '8470' 'port': '8470'
}], client.network_endpoints()) }], c.network_endpoints())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata', @mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata) mock_request_compute_metadata)
def testInitializeNoArgumentsWithEnvironmentVariable(self): def testInitializeNoArgumentsWithEnvironmentVariable(self):
os.environ['TPU_NAME'] = 'tpu_name' os.environ['TPU_NAME'] = 'tpu_name'
@ -156,11 +156,11 @@ class CloudTpuClientTest(test.TestCase):
'health': 'HEALTHY' 'health': 'HEALTHY'
} }
} }
client = cloud_tpu_client.CloudTPUClient( c = client.Client(
service=self.mock_service_client(tpu_map=tpu_map)) service=self.mock_service_client(tpu_map=tpu_map))
self.assertClientContains(client) self.assertClientContains(c)
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata', @mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata) mock_request_compute_metadata)
def testInitializeTpuName(self): def testInitializeTpuName(self):
tpu_map = { tpu_map = {
@ -170,42 +170,42 @@ class CloudTpuClientTest(test.TestCase):
'health': 'HEALTHY' 'health': 'HEALTHY'
} }
} }
client = cloud_tpu_client.CloudTPUClient( c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map)) tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertClientContains(client) self.assertClientContains(c)
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata', @mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata) mock_request_compute_metadata)
def testInitializeIpAddress(self): def testInitializeIpAddress(self):
client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470') c = client.Client(tpu='grpc://1.2.3.4:8470')
self.assertEqual('grpc://1.2.3.4:8470', client._tpu) self.assertEqual('grpc://1.2.3.4:8470', c._tpu)
self.assertEqual(False, client._use_api) self.assertEqual(False, c._use_api)
self.assertEqual(None, client._service) self.assertEqual(None, c._service)
self.assertEqual(None, client._credentials) self.assertEqual(None, c._credentials)
self.assertEqual(None, client._project) self.assertEqual(None, c._project)
self.assertEqual(None, client._zone) self.assertEqual(None, c._zone)
self.assertEqual(None, client._discovery_url) self.assertEqual(None, c._discovery_url)
self.assertEqual([{ self.assertEqual([{
'ipAddress': '1.2.3.4', 'ipAddress': '1.2.3.4',
'port': '8470' 'port': '8470'
}], client.network_endpoints()) }], c.network_endpoints())
def testInitializeWithoutMetadata(self): def testInitializeWithoutMetadata(self):
client = cloud_tpu_client.CloudTPUClient( c = client.Client(
tpu='tpu_name', project='project', zone='zone') tpu='tpu_name', project='project', zone='zone')
self.assertEqual('tpu_name', client._tpu) self.assertEqual('tpu_name', c._tpu)
self.assertEqual(True, client._use_api) self.assertEqual(True, c._use_api)
self.assertEqual(None, client._service) self.assertEqual(None, c._service)
self.assertEqual(None, client._credentials) self.assertEqual(None, c._credentials)
self.assertEqual('project', client._project) self.assertEqual('project', c._project)
self.assertEqual('zone', client._zone) self.assertEqual('zone', c._zone)
self.assertEqual(None, client._discovery_url) self.assertEqual(None, c._discovery_url)
def testRecoverableNoApiAccess(self): def testRecoverableNoApiAccess(self):
client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470') c = client.Client(tpu='grpc://1.2.3.4:8470')
self.assertEqual(True, client.recoverable()) self.assertEqual(True, c.recoverable())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata', @mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata) mock_request_compute_metadata)
def testRecoverableNoState(self): def testRecoverableNoState(self):
tpu_map = { tpu_map = {
@ -214,11 +214,11 @@ class CloudTpuClientTest(test.TestCase):
'port': '8470', 'port': '8470',
} }
} }
client = cloud_tpu_client.CloudTPUClient( c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map)) tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(True, client.recoverable()) self.assertEqual(True, c.recoverable())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata', @mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata) mock_request_compute_metadata)
def testRecoverableReady(self): def testRecoverableReady(self):
tpu_map = { tpu_map = {
@ -228,11 +228,11 @@ class CloudTpuClientTest(test.TestCase):
'state': 'READY', 'state': 'READY',
} }
} }
client = cloud_tpu_client.CloudTPUClient( c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map)) tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(True, client.recoverable()) self.assertEqual(True, c.recoverable())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata', @mock.patch.object(client, '_request_compute_metadata',
mock_request_compute_metadata) mock_request_compute_metadata)
def testRecoverablePreempted(self): def testRecoverablePreempted(self):
tpu_map = { tpu_map = {
@ -242,9 +242,9 @@ class CloudTpuClientTest(test.TestCase):
'state': 'PREEMPTED', 'state': 'PREEMPTED',
} }
} }
client = cloud_tpu_client.CloudTPUClient( c = client.Client(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map)) tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(False, client.recoverable()) self.assertEqual(False, c.recoverable())
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -0,0 +1,15 @@
# Description:
# Tools for building the Cloud TPU Client pip package.
package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0
# Script that assembles the cloud-tpu-client wheel. setup.py and the client
# package are declared as data so they are present in the script's runfiles,
# which is where build_pip_package.sh copies the sources from.
sh_binary(
name = "build_pip_package",
srcs = ["build_pip_package.sh"],
data = [
"setup.py",
"//tensorflow/python/tpu/client:client_lib",
],
)

View File

@ -0,0 +1,3 @@
Client responsible for communicating with the Cloud TPU API. Released separately from TensorFlow.
https://pypi.org/project/cloud-tpu-client/

View File

@ -0,0 +1,65 @@
#!/usr/bin/env bash
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
set -e

# BSD/macOS sed requires an explicit (possibly empty) backup suffix after -i;
# GNU sed does not. The command lives in an array so that the empty suffix
# argument survives expansion. Stored in a plain string ("sed -i ''") and
# expanded unquoted, the two quote characters would be handed to sed literally
# and every edited file would gain a stray backup copy.
if [ "$(uname)" = "Darwin" ]; then
  sedi=(sed -i '')
else
  sedi=(sed -i)
fi

PACKAGE_NAME="cloud_tpu_client"
PIP_PACKAGE="tensorflow/python/tpu/client/pip_package"
RUNFILES="bazel-bin/tensorflow/python/tpu/client/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow/python/tpu/client"

# Builds the cloud-tpu-client wheels (once with `python`, once with `python3`)
# and copies them into the directory given as the first argument.
#
# PIP_PACKAGE and RUNFILES are relative paths, so the script presumably has to
# be started from the workspace root after the bazel target has been built.
function main() {
  if [ $# -lt 1 ] ; then
    echo "No destination dir provided"
    exit 1
  fi

  # Resolve the destination to an absolute path before leaving the current
  # directory. The wheels are copied after `pushd` into the staging directory,
  # so a relative destination would otherwise be created inside the staging
  # directory and deleted together with it.
  mkdir -p "$1"
  DEST="$(cd "$1" && pwd)"

  # Not called TMPDIR on purpose: that name is the conventional environment
  # variable consulted by mktemp and other tools for the temp-file location.
  STAGING_DIR="$(mktemp -d -t tmp.XXXXXXXXXX)"
  # Clean up on every exit path, including aborts triggered by `set -e`.
  trap 'rm -rf "${STAGING_DIR}"' EXIT
  echo $(date) : "=== Using tmpdir: ${STAGING_DIR}"

  cp "${PIP_PACKAGE}/README" "${STAGING_DIR}"
  cp "${PIP_PACKAGE}/setup.py" "${STAGING_DIR}"
  mkdir "${STAGING_DIR}/${PACKAGE_NAME}"
  cp -a "${RUNFILES}/." "${STAGING_DIR}/${PACKAGE_NAME}/"

  # Fix the import statements to reflect the copied over path. The dots are
  # escaped so that only the literal module path is rewritten.
  find "${STAGING_DIR}/${PACKAGE_NAME}" -name '*.py' -exec \
    "${sedi[@]}" -e "s/^from tensorflow\.python\.tpu\.client/from ${PACKAGE_NAME}/" {} +

  echo $(ls "${STAGING_DIR}")

  pushd "${STAGING_DIR}"
  echo $(date) : "=== Building wheel"
  echo $(pwd)
  python setup.py bdist_wheel >/dev/null
  python3 setup.py bdist_wheel >/dev/null
  cp dist/* "${DEST}"
  popd
  echo $(date) : "=== Output wheel file is in: ${DEST}"
}

main "$@"

View File

@ -0,0 +1,56 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Cloud TPU Client package."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from cloud_tpu_client.version import __version__
from setuptools import find_packages
from setuptools import setup
# One-line summary, used for both the short and the long description.
_SUMMARY = 'Client for using Cloud TPUs'

# setuptools receives the version with every '-' character removed.
_VERSION = __version__.replace('-', '')

# Packages installed alongside the client.
_REQUIRED_PACKAGES = ['google-api-python-client', 'oauth2client']

# Trove classifiers attached to the distribution metadata.
_CLASSIFIERS = [
    'Development Status :: 5 - Production/Stable',
    'Intended Audience :: Developers',
    'Intended Audience :: Education',
    'Intended Audience :: Science/Research',
    'License :: OSI Approved :: Apache Software License',
    'Programming Language :: Python :: 2',
    'Programming Language :: Python :: 2.7',
    'Programming Language :: Python :: 3',
    'Programming Language :: Python :: 3.4',
    'Programming Language :: Python :: 3.5',
    'Programming Language :: Python :: 3.6',
    'Topic :: Scientific/Engineering',
    'Topic :: Scientific/Engineering :: Mathematics',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'Topic :: Software Development',
    'Topic :: Software Development :: Libraries',
    'Topic :: Software Development :: Libraries :: Python Modules',
]

setup(
    name='cloud-tpu-client',
    version=_VERSION,
    description=_SUMMARY,
    long_description=_SUMMARY,
    url='https://cloud.google.com/tpu/',
    author='Google Inc.',
    author_email='packages@tensorflow.org',
    license='Apache 2.0',
    keywords='tensorflow tpu',
    classifiers=_CLASSIFIERS,
    packages=find_packages(),
    install_requires=_REQUIRED_PACKAGES,
)

View File

@ -0,0 +1,21 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Cloud TPU Client version information."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Version of the cloud-tpu-client package. setup.py imports this value and
# removes any '-' characters before passing it to setuptools.
__version__ = "0.2"

View File

@ -21,10 +21,10 @@
function install_ctpu { function install_ctpu {
PIP_CMD="${1:-pip}" PIP_CMD="${1:-pip}"
# TPUClusterResolver has a runtime dependency on these Python libraries when # TPUClusterResolver has a runtime dependency on cloud-tpu-client when
# resolving a Cloud TPU. It's very likely we want these installed if we're # resolving a Cloud TPU. It's very likely we want this installed if we're
# using CTPU. # using CTPU.
"${PIP_CMD}" install --user --upgrade google-api-python-client oauth2client "${PIP_CMD}" install --user --upgrade cloud-tpu-client
wget -nv "https://dl.google.com/cloud_tpu/ctpu/latest/linux/ctpu" wget -nv "https://dl.google.com/cloud_tpu/ctpu/latest/linux/ctpu"
chmod a+x ctpu chmod a+x ctpu