Create pip scripts for cloud tpu client.
PiperOrigin-RevId: 285260886 Change-Id: I15f4dbc3f1bbab44700b855e19251c7df1c46c31
This commit is contained in:
parent
b18db52708
commit
5364121e85
@ -60,58 +60,14 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "cloud_tpu_client",
|
||||
srcs = ["cloud_tpu_client.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "cloud_tpu_client_py_test",
|
||||
size = "small",
|
||||
srcs = ["cloud_tpu_client_test.py"],
|
||||
additional_deps = [
|
||||
":cloud_tpu_client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
main = "cloud_tpu_client_test.py",
|
||||
python_version = "PY3",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "cloud_tpu_client_py2_test",
|
||||
size = "small",
|
||||
srcs = ["cloud_tpu_client_test.py"],
|
||||
additional_deps = [
|
||||
":cloud_tpu_client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
main = "cloud_tpu_client_test.py",
|
||||
python_version = "PY3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_cluster_resolver_py",
|
||||
srcs = ["tpu_cluster_resolver.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":cloud_tpu_client",
|
||||
":base_cluster_resolver_py",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
"//tensorflow/python/tpu/client",
|
||||
] + tf_additional_rpc_deps(),
|
||||
)
|
||||
|
||||
@ -191,6 +147,7 @@ tf_py_test(
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
"//tensorflow/python/tpu/client:client",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
main = "tpu_cluster_resolver_test.py",
|
||||
|
@ -21,16 +21,20 @@ from __future__ import print_function
|
||||
import collections
|
||||
import re
|
||||
|
||||
from tensorflow.python.distribute.cluster_resolver.cloud_tpu_client import CloudTPUClient
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
|
||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import get_accelerator_devices
|
||||
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
try:
|
||||
from cloud_tpu_client import client # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
logging.warning(
|
||||
'Falling back to tensorflow client, its recommended to install the cloud '
|
||||
'tpu client directly with pip install cloud-tpu-client .')
|
||||
from tensorflow.python.tpu.client import client
|
||||
|
||||
def is_running_in_gce():
|
||||
return True
|
||||
@ -44,7 +48,7 @@ DeviceDetails = collections.namedtuple(
|
||||
|
||||
|
||||
@tf_export('distribute.cluster_resolver.TPUClusterResolver')
|
||||
class TPUClusterResolver(ClusterResolver):
|
||||
class TPUClusterResolver(cluster_resolver.ClusterResolver):
|
||||
"""Cluster Resolver for Google Cloud TPUs.
|
||||
|
||||
This is an implementation of cluster resolvers for the Google Cloud TPU
|
||||
@ -135,7 +139,6 @@ class TPUClusterResolver(ClusterResolver):
|
||||
filled in produce an absolute URL to the discovery document for that
|
||||
service. The environment variable 'TPU_API_DISCOVERY_URL' will override
|
||||
this.
|
||||
**kwargs: Extra keyword arguments passed to CloudTPUClient.
|
||||
|
||||
Raises:
|
||||
ImportError: If the googleapiclient is not installed.
|
||||
@ -144,7 +147,7 @@ class TPUClusterResolver(ClusterResolver):
|
||||
Google Cloud environment.
|
||||
"""
|
||||
|
||||
self._cloud_tpu_client = CloudTPUClient(
|
||||
self._cloud_tpu_client = client.Client(
|
||||
tpu=tpu,
|
||||
zone=zone,
|
||||
project=project,
|
||||
@ -208,7 +211,7 @@ class TPUClusterResolver(ClusterResolver):
|
||||
if not job_tasks:
|
||||
raise ValueError('No TPUs with the specified names exist.')
|
||||
master = job_tasks[0]
|
||||
return format_master_url(master, 'grpc')
|
||||
return cluster_resolver.format_master_url(master, 'grpc')
|
||||
|
||||
def get_master(self):
|
||||
return self.master()
|
||||
@ -277,7 +280,8 @@ class TPUClusterResolver(ClusterResolver):
|
||||
while True:
|
||||
try:
|
||||
device_details = TPUClusterResolver._get_device_dict_and_cores(
|
||||
get_accelerator_devices(self.master(), config_proto=config_proto))
|
||||
cluster_resolver.get_accelerator_devices(
|
||||
self.master(), config_proto=config_proto))
|
||||
break
|
||||
except errors.DeadlineExceededError:
|
||||
error_message = ('Failed to connect to master. The TPU might not be '
|
||||
|
@ -25,16 +25,24 @@ from six.moves.urllib.error import URLError
|
||||
|
||||
from tensorflow.python import framework
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
|
||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
|
||||
from tensorflow.python.eager.context import LogicalDevice
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import compat
|
||||
mock = test.mock
|
||||
|
||||
try:
|
||||
from cloud_tpu_client import client # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
logging.warning(
|
||||
'Falling back to tensorflow client, its recommended to install the cloud '
|
||||
'tpu client directly with pip install cloud-tpu-client .')
|
||||
from tensorflow.python.tpu.client import client
|
||||
|
||||
|
||||
class MockRequestClass(object):
|
||||
|
||||
@ -141,7 +149,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
def testIsRunningInGce(self):
|
||||
self.assertTrue(resolver.is_running_in_gce())
|
||||
|
||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testRetrieveProjectAndZoneFromMetadata(self):
|
||||
tpu_map = {
|
||||
@ -174,7 +182,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
|
||||
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
|
||||
|
||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self):
|
||||
tpu_map = {
|
||||
@ -200,7 +208,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
|
||||
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
|
||||
|
||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testNotReadyCloudTpu(self):
|
||||
tpu_map = {
|
||||
@ -299,7 +307,7 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
|
||||
self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master())
|
||||
|
||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testPodResolution(self):
|
||||
tpu_map = {
|
||||
|
68
tensorflow/python/tpu/client/BUILD
Normal file
68
tensorflow/python/tpu/client/BUILD
Normal file
@ -0,0 +1,68 @@
|
||||
# Cloud TPU Client.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "client",
|
||||
srcs = [
|
||||
"client.py",
|
||||
"version.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "client_lib",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":client",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "client_py_test",
|
||||
size = "small",
|
||||
srcs = ["client_test.py"],
|
||||
additional_deps = [
|
||||
":client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
main = "client_test.py",
|
||||
python_version = "PY3",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "client_py2_test",
|
||||
size = "small",
|
||||
srcs = ["client_test.py"],
|
||||
additional_deps = [
|
||||
":client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training_server_lib",
|
||||
],
|
||||
grpc_enabled = True,
|
||||
main = "client_test.py",
|
||||
python_version = "PY2",
|
||||
)
|
21
tensorflow/python/tpu/client/__init__.py
Normal file
21
tensorflow/python/tpu/client/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Cloud TPU Client."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.tpu.client.client import Client
|
@ -23,8 +23,6 @@ import os
|
||||
|
||||
from six.moves.urllib import request
|
||||
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_GOOGLE_API_CLIENT_INSTALLED = True
|
||||
try:
|
||||
from apiclient import discovery # pylint: disable=g-import-not-at-top
|
||||
@ -49,23 +47,23 @@ def _request_compute_metadata(path):
|
||||
'%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path),
|
||||
headers={'Metadata-Flavor': 'Google'})
|
||||
resp = request.urlopen(req)
|
||||
return compat.as_bytes(resp.read())
|
||||
return resp.read()
|
||||
|
||||
|
||||
def _environment_var_to_network_endpoints(endpoints):
|
||||
"""Yields a dict with ip address and port."""
|
||||
for endpoint in endpoints.split(compat.as_text(',')):
|
||||
grpc_prefix = compat.as_text('grpc://')
|
||||
for endpoint in endpoints.split(','):
|
||||
grpc_prefix = 'grpc://'
|
||||
if endpoint.startswith(grpc_prefix):
|
||||
endpoint = endpoint.split(grpc_prefix)[1]
|
||||
parts = endpoint.split(compat.as_text(':'))
|
||||
parts = endpoint.split(':')
|
||||
ip_address = parts[0]
|
||||
port = _DEFAULT_ENDPOINT_PORT
|
||||
if len(parts) > 1:
|
||||
port = parts[1]
|
||||
yield {
|
||||
'ipAddress': compat.as_text(ip_address),
|
||||
'port': compat.as_text(port)
|
||||
'ipAddress': ip_address,
|
||||
'port': port
|
||||
}
|
||||
|
||||
|
||||
@ -79,7 +77,7 @@ def _get_tpu_name(tpu):
|
||||
return None
|
||||
|
||||
|
||||
class CloudTPUClient(object):
|
||||
class Client(object):
|
||||
"""Client for working with the Cloud TPU API.
|
||||
|
||||
This client is intended to be used for resolving tpu name to ip addresses.
|
||||
@ -108,7 +106,7 @@ class CloudTPUClient(object):
|
||||
if tpu is None:
|
||||
raise ValueError('Please provide a TPU Name to connect to.')
|
||||
|
||||
self._tpu = compat.as_text(tpu)
|
||||
self._tpu = tpu
|
||||
|
||||
self._use_api = not tpu.startswith('grpc://')
|
||||
self._service = service
|
||||
@ -124,12 +122,11 @@ class CloudTPUClient(object):
|
||||
if project:
|
||||
self._project = project
|
||||
else:
|
||||
self._project = compat.as_str(
|
||||
_request_compute_metadata('project/project-id'))
|
||||
self._project = _request_compute_metadata('project/project-id')
|
||||
if zone:
|
||||
self._zone = zone
|
||||
else:
|
||||
zone_path = compat.as_str(_request_compute_metadata('instance/zone'))
|
||||
zone_path = _request_compute_metadata('instance/zone')
|
||||
self._zone = zone_path.split('/')[-1]
|
||||
self._discovery_url = _environment_discovery_url() or discovery_url
|
||||
|
||||
@ -166,7 +163,7 @@ class CloudTPUClient(object):
|
||||
"""Returns the TPU metadata object from the TPU Get API call."""
|
||||
try:
|
||||
full_name = 'projects/%s/locations/%s/nodes/%s' % (
|
||||
self._project, self._zone, compat.as_text(self._tpu))
|
||||
self._project, self._zone, self._tpu)
|
||||
service = self._tpu_service()
|
||||
r = service.projects().locations().nodes().get(name=full_name)
|
||||
return r.execute()
|
||||
@ -220,7 +217,7 @@ class CloudTPUClient(object):
|
||||
|
||||
if 'state' in response and response['state'] != 'READY':
|
||||
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
|
||||
(compat.as_text(self._tpu), response['state']))
|
||||
(self._tpu, response['state']))
|
||||
if 'networkEndpoints' in response:
|
||||
return response['networkEndpoints']
|
||||
else:
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# Lint as: python3
|
||||
"""Tests for cloud_tpu_client."""
|
||||
"""Tests for cloud tpu client."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -21,8 +21,8 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu.client import client
|
||||
|
||||
mock = test.mock
|
||||
|
||||
@ -85,19 +85,19 @@ class CloudTpuClientTest(test.TestCase):
|
||||
def testEnvironmentDiscoveryUrl(self):
|
||||
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
|
||||
self.assertEqual('https://{api}.internal/{apiVersion}',
|
||||
(cloud_tpu_client._environment_discovery_url()))
|
||||
(client._environment_discovery_url()))
|
||||
|
||||
def testEnvironmentVarToNetworkEndpointsSingleIp(self):
|
||||
self.assertEqual(
|
||||
[{'ipAddress': '1.2.3.4', 'port': '1234'}],
|
||||
list(cloud_tpu_client._environment_var_to_network_endpoints(
|
||||
list(client._environment_var_to_network_endpoints(
|
||||
'1.2.3.4:1234')))
|
||||
|
||||
def testEnvironmentVarToNetworkEndpointsSingleGrpcAddress(self):
|
||||
self.assertEqual(
|
||||
[{'ipAddress': '1.2.3.4', 'port': '2000'}],
|
||||
list(
|
||||
cloud_tpu_client._environment_var_to_network_endpoints(
|
||||
client._environment_var_to_network_endpoints(
|
||||
'grpc://1.2.3.4:2000')))
|
||||
|
||||
def testEnvironmentVarToNetworkEndpointsMultipleIps(self):
|
||||
@ -105,47 +105,47 @@ class CloudTpuClientTest(test.TestCase):
|
||||
[{'ipAddress': '1.2.3.4', 'port': '2000'},
|
||||
{'ipAddress': '5.6.7.8', 'port': '1234'}],
|
||||
list(
|
||||
cloud_tpu_client._environment_var_to_network_endpoints(
|
||||
client._environment_var_to_network_endpoints(
|
||||
'1.2.3.4:2000,5.6.7.8:1234')))
|
||||
|
||||
def testEnvironmentVarToNetworkEndpointsMultipleGrpcAddresses(self):
|
||||
self.assertEqual(
|
||||
[{'ipAddress': '1.2.3.4', 'port': '2000'},
|
||||
{'ipAddress': '5.6.7.8', 'port': '1234'}],
|
||||
list(cloud_tpu_client._environment_var_to_network_endpoints(
|
||||
list(client._environment_var_to_network_endpoints(
|
||||
'grpc://1.2.3.4:2000,grpc://5.6.7.8:1234')))
|
||||
|
||||
def testEnvironmentVarToNetworkEndpointsMissingPortAndMixed(self):
|
||||
self.assertEqual(
|
||||
[{'ipAddress': '1.2.3.4', 'port': '2000'},
|
||||
{'ipAddress': '5.6.7.8', 'port': '8470'}],
|
||||
list(cloud_tpu_client._environment_var_to_network_endpoints(
|
||||
list(client._environment_var_to_network_endpoints(
|
||||
'1.2.3.4:2000,grpc://5.6.7.8')))
|
||||
|
||||
def testInitializeNoArguments(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Please provide a TPU Name to connect to.'):
|
||||
cloud_tpu_client.CloudTPUClient()
|
||||
client.Client()
|
||||
|
||||
def testInitializeMultiElementTpuArray(self):
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Using multiple TPUs in a single session is not yet implemented'):
|
||||
cloud_tpu_client.CloudTPUClient(tpu=['multiple', 'elements'])
|
||||
client.Client(tpu=['multiple', 'elements'])
|
||||
|
||||
def assertClientContains(self, client):
|
||||
self.assertEqual('tpu_name', client._tpu)
|
||||
self.assertEqual(True, client._use_api)
|
||||
self.assertEqual(None, client._credentials)
|
||||
self.assertEqual('test-project', client._project)
|
||||
self.assertEqual('us-central1-c', client._zone)
|
||||
self.assertEqual(None, client._discovery_url)
|
||||
def assertClientContains(self, c):
|
||||
self.assertEqual('tpu_name', c._tpu)
|
||||
self.assertEqual(True, c._use_api)
|
||||
self.assertEqual(None, c._credentials)
|
||||
self.assertEqual('test-project', c._project)
|
||||
self.assertEqual('us-central1-c', c._zone)
|
||||
self.assertEqual(None, c._discovery_url)
|
||||
self.assertEqual([{
|
||||
'ipAddress': '10.1.2.3',
|
||||
'port': '8470'
|
||||
}], client.network_endpoints())
|
||||
}], c.network_endpoints())
|
||||
|
||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testInitializeNoArgumentsWithEnvironmentVariable(self):
|
||||
os.environ['TPU_NAME'] = 'tpu_name'
|
||||
@ -156,11 +156,11 @@ class CloudTpuClientTest(test.TestCase):
|
||||
'health': 'HEALTHY'
|
||||
}
|
||||
}
|
||||
client = cloud_tpu_client.CloudTPUClient(
|
||||
c = client.Client(
|
||||
service=self.mock_service_client(tpu_map=tpu_map))
|
||||
self.assertClientContains(client)
|
||||
self.assertClientContains(c)
|
||||
|
||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testInitializeTpuName(self):
|
||||
tpu_map = {
|
||||
@ -170,42 +170,42 @@ class CloudTpuClientTest(test.TestCase):
|
||||
'health': 'HEALTHY'
|
||||
}
|
||||
}
|
||||
client = cloud_tpu_client.CloudTPUClient(
|
||||
c = client.Client(
|
||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||
self.assertClientContains(client)
|
||||
self.assertClientContains(c)
|
||||
|
||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testInitializeIpAddress(self):
|
||||
client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
|
||||
self.assertEqual('grpc://1.2.3.4:8470', client._tpu)
|
||||
self.assertEqual(False, client._use_api)
|
||||
self.assertEqual(None, client._service)
|
||||
self.assertEqual(None, client._credentials)
|
||||
self.assertEqual(None, client._project)
|
||||
self.assertEqual(None, client._zone)
|
||||
self.assertEqual(None, client._discovery_url)
|
||||
c = client.Client(tpu='grpc://1.2.3.4:8470')
|
||||
self.assertEqual('grpc://1.2.3.4:8470', c._tpu)
|
||||
self.assertEqual(False, c._use_api)
|
||||
self.assertEqual(None, c._service)
|
||||
self.assertEqual(None, c._credentials)
|
||||
self.assertEqual(None, c._project)
|
||||
self.assertEqual(None, c._zone)
|
||||
self.assertEqual(None, c._discovery_url)
|
||||
self.assertEqual([{
|
||||
'ipAddress': '1.2.3.4',
|
||||
'port': '8470'
|
||||
}], client.network_endpoints())
|
||||
}], c.network_endpoints())
|
||||
|
||||
def testInitializeWithoutMetadata(self):
|
||||
client = cloud_tpu_client.CloudTPUClient(
|
||||
c = client.Client(
|
||||
tpu='tpu_name', project='project', zone='zone')
|
||||
self.assertEqual('tpu_name', client._tpu)
|
||||
self.assertEqual(True, client._use_api)
|
||||
self.assertEqual(None, client._service)
|
||||
self.assertEqual(None, client._credentials)
|
||||
self.assertEqual('project', client._project)
|
||||
self.assertEqual('zone', client._zone)
|
||||
self.assertEqual(None, client._discovery_url)
|
||||
self.assertEqual('tpu_name', c._tpu)
|
||||
self.assertEqual(True, c._use_api)
|
||||
self.assertEqual(None, c._service)
|
||||
self.assertEqual(None, c._credentials)
|
||||
self.assertEqual('project', c._project)
|
||||
self.assertEqual('zone', c._zone)
|
||||
self.assertEqual(None, c._discovery_url)
|
||||
|
||||
def testRecoverableNoApiAccess(self):
|
||||
client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
|
||||
self.assertEqual(True, client.recoverable())
|
||||
c = client.Client(tpu='grpc://1.2.3.4:8470')
|
||||
self.assertEqual(True, c.recoverable())
|
||||
|
||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testRecoverableNoState(self):
|
||||
tpu_map = {
|
||||
@ -214,11 +214,11 @@ class CloudTpuClientTest(test.TestCase):
|
||||
'port': '8470',
|
||||
}
|
||||
}
|
||||
client = cloud_tpu_client.CloudTPUClient(
|
||||
c = client.Client(
|
||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||
self.assertEqual(True, client.recoverable())
|
||||
self.assertEqual(True, c.recoverable())
|
||||
|
||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testRecoverableReady(self):
|
||||
tpu_map = {
|
||||
@ -228,11 +228,11 @@ class CloudTpuClientTest(test.TestCase):
|
||||
'state': 'READY',
|
||||
}
|
||||
}
|
||||
client = cloud_tpu_client.CloudTPUClient(
|
||||
c = client.Client(
|
||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||
self.assertEqual(True, client.recoverable())
|
||||
self.assertEqual(True, c.recoverable())
|
||||
|
||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
||||
@mock.patch.object(client, '_request_compute_metadata',
|
||||
mock_request_compute_metadata)
|
||||
def testRecoverablePreempted(self):
|
||||
tpu_map = {
|
||||
@ -242,9 +242,9 @@ class CloudTpuClientTest(test.TestCase):
|
||||
'state': 'PREEMPTED',
|
||||
}
|
||||
}
|
||||
client = cloud_tpu_client.CloudTPUClient(
|
||||
c = client.Client(
|
||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||
self.assertEqual(False, client.recoverable())
|
||||
self.assertEqual(False, c.recoverable())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
15
tensorflow/python/tpu/client/pip_package/BUILD
Normal file
15
tensorflow/python/tpu/client/pip_package/BUILD
Normal file
@ -0,0 +1,15 @@
|
||||
# Description:
|
||||
# Tools for building the Cloud TPU Client pip package.
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
sh_binary(
|
||||
name = "build_pip_package",
|
||||
srcs = ["build_pip_package.sh"],
|
||||
data = [
|
||||
"setup.py",
|
||||
"//tensorflow/python/tpu/client:client_lib",
|
||||
],
|
||||
)
|
3
tensorflow/python/tpu/client/pip_package/README
Normal file
3
tensorflow/python/tpu/client/pip_package/README
Normal file
@ -0,0 +1,3 @@
|
||||
Client responsible for communicating the Cloud TPU API. Released seperately from tensorflow.
|
||||
|
||||
https://pypi.org/project/cloud-tpu-client/
|
65
tensorflow/python/tpu/client/pip_package/build_pip_package.sh
Executable file
65
tensorflow/python/tpu/client/pip_package/build_pip_package.sh
Executable file
@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
|
||||
set -e
|
||||
|
||||
if [ "$(uname)" = "Darwin" ]; then
|
||||
sedi="sed -i ''"
|
||||
else
|
||||
sedi="sed -i"
|
||||
fi
|
||||
|
||||
PACKAGE_NAME="cloud_tpu_client"
|
||||
PIP_PACKAGE="tensorflow/python/tpu/client/pip_package"
|
||||
RUNFILES="bazel-bin/tensorflow/python/tpu/client/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow/python/tpu/client"
|
||||
|
||||
function main() {
|
||||
if [ $# -lt 1 ] ; then
|
||||
echo "No destination dir provided"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DEST=$1
|
||||
TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX)
|
||||
|
||||
echo $(date) : "=== Using tmpdir: ${TMPDIR}"
|
||||
|
||||
cp ${PIP_PACKAGE}/README ${TMPDIR}
|
||||
cp ${PIP_PACKAGE}/setup.py ${TMPDIR}
|
||||
mkdir ${TMPDIR}/${PACKAGE_NAME}
|
||||
cp -a ${RUNFILES}/. ${TMPDIR}/${PACKAGE_NAME}/
|
||||
|
||||
# Fix the import statements to reflect the copied over path.
|
||||
find ${TMPDIR}/${PACKAGE_NAME} -name \*.py |
|
||||
xargs $sedi -e '
|
||||
s/^from tensorflow.python.tpu.client/from '${PACKAGE_NAME}'/
|
||||
'
|
||||
echo $(ls $TMPDIR)
|
||||
|
||||
pushd ${TMPDIR}
|
||||
echo $(date) : "=== Building wheel"
|
||||
echo $(pwd)
|
||||
python setup.py bdist_wheel >/dev/null
|
||||
python3 setup.py bdist_wheel >/dev/null
|
||||
mkdir -p ${DEST}
|
||||
cp dist/* ${DEST}
|
||||
popd
|
||||
rm -rf ${TMPDIR}
|
||||
echo $(date) : "=== Output wheel file is in: ${DEST}"
|
||||
}
|
||||
|
||||
main "$@"
|
56
tensorflow/python/tpu/client/pip_package/setup.py
Normal file
56
tensorflow/python/tpu/client/pip_package/setup.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Cloud TPU Client package."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from cloud_tpu_client.version import __version__
|
||||
from setuptools import find_packages
|
||||
from setuptools import setup
|
||||
|
||||
setup(
|
||||
name='cloud-tpu-client',
|
||||
version=__version__.replace('-', ''),
|
||||
description='Client for using Cloud TPUs',
|
||||
long_description='Client for using Cloud TPUs',
|
||||
url='https://cloud.google.com/tpu/',
|
||||
author='Google Inc.',
|
||||
author_email='packages@tensorflow.org',
|
||||
packages=find_packages(),
|
||||
classifiers=[
|
||||
'Development Status :: 5 - Production/Stable',
|
||||
'Intended Audience :: Developers',
|
||||
'Intended Audience :: Education',
|
||||
'Intended Audience :: Science/Research',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Programming Language :: Python :: 2',
|
||||
'Programming Language :: Python :: 2.7',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.4',
|
||||
'Programming Language :: Python :: 3.5',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Topic :: Scientific/Engineering',
|
||||
'Topic :: Scientific/Engineering :: Mathematics',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development',
|
||||
'Topic :: Software Development :: Libraries',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||
],
|
||||
license='Apache 2.0',
|
||||
keywords='tensorflow tpu',
|
||||
install_requires=['google-api-python-client', 'oauth2client']
|
||||
)
|
21
tensorflow/python/tpu/client/version.py
Normal file
21
tensorflow/python/tpu/client/version.py
Normal file
@ -0,0 +1,21 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Cloud TPU Client version information."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
__version__ = "0.2"
|
@ -21,10 +21,10 @@
|
||||
function install_ctpu {
|
||||
PIP_CMD="${1:-pip}"
|
||||
|
||||
# TPUClusterResolver has a runtime dependency on these Python libraries when
|
||||
# resolving a Cloud TPU. It's very likely we want these installed if we're
|
||||
# TPUClusterResolver has a runtime dependency cloud-tpu-client when
|
||||
# resolving a Cloud TPU. It's very likely we want this installed if we're
|
||||
# using CTPU.
|
||||
"${PIP_CMD}" install --user --upgrade google-api-python-client oauth2client
|
||||
"${PIP_CMD}" install --user --upgrade cloud-tpu-client
|
||||
|
||||
wget -nv "https://dl.google.com/cloud_tpu/ctpu/latest/linux/ctpu"
|
||||
chmod a+x ctpu
|
||||
|
Loading…
x
Reference in New Issue
Block a user