Create pip scripts for cloud tpu client.
PiperOrigin-RevId: 285260886 Change-Id: I15f4dbc3f1bbab44700b855e19251c7df1c46c31
This commit is contained in:
parent
b18db52708
commit
5364121e85
@ -60,58 +60,14 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "cloud_tpu_client",
|
|
||||||
srcs = ["cloud_tpu_client.py"],
|
|
||||||
srcs_version = "PY2AND3",
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/python:util",
|
|
||||||
"@six_archive//:six",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_py_test(
|
|
||||||
name = "cloud_tpu_client_py_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["cloud_tpu_client_test.py"],
|
|
||||||
additional_deps = [
|
|
||||||
":cloud_tpu_client",
|
|
||||||
"//tensorflow/python:client_testlib",
|
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
|
||||||
"//tensorflow/python:framework_test_lib",
|
|
||||||
"//tensorflow/python:platform_test",
|
|
||||||
"//tensorflow/python:training_server_lib",
|
|
||||||
],
|
|
||||||
grpc_enabled = True,
|
|
||||||
main = "cloud_tpu_client_test.py",
|
|
||||||
python_version = "PY3",
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_py_test(
|
|
||||||
name = "cloud_tpu_client_py2_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["cloud_tpu_client_test.py"],
|
|
||||||
additional_deps = [
|
|
||||||
":cloud_tpu_client",
|
|
||||||
"//tensorflow/python:client_testlib",
|
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
|
||||||
"//tensorflow/python:framework_test_lib",
|
|
||||||
"//tensorflow/python:platform_test",
|
|
||||||
"//tensorflow/python:training_server_lib",
|
|
||||||
],
|
|
||||||
grpc_enabled = True,
|
|
||||||
main = "cloud_tpu_client_test.py",
|
|
||||||
python_version = "PY3",
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "tpu_cluster_resolver_py",
|
name = "tpu_cluster_resolver_py",
|
||||||
srcs = ["tpu_cluster_resolver.py"],
|
srcs = ["tpu_cluster_resolver.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":cloud_tpu_client",
|
|
||||||
":base_cluster_resolver_py",
|
":base_cluster_resolver_py",
|
||||||
"//tensorflow/python:training_server_lib",
|
"//tensorflow/python:training_server_lib",
|
||||||
|
"//tensorflow/python/tpu/client",
|
||||||
] + tf_additional_rpc_deps(),
|
] + tf_additional_rpc_deps(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -191,6 +147,7 @@ tf_py_test(
|
|||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:training_server_lib",
|
"//tensorflow/python:training_server_lib",
|
||||||
|
"//tensorflow/python/tpu/client:client",
|
||||||
],
|
],
|
||||||
grpc_enabled = True,
|
grpc_enabled = True,
|
||||||
main = "tpu_cluster_resolver_test.py",
|
main = "tpu_cluster_resolver_test.py",
|
||||||
|
@ -21,16 +21,20 @@ from __future__ import print_function
|
|||||||
import collections
|
import collections
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from tensorflow.python.distribute.cluster_resolver.cloud_tpu_client import CloudTPUClient
|
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
|
||||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
|
|
||||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
|
|
||||||
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import get_accelerator_devices
|
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
try:
|
||||||
|
from cloud_tpu_client import client # pylint: disable=g-import-not-at-top
|
||||||
|
except ImportError:
|
||||||
|
logging.warning(
|
||||||
|
'Falling back to tensorflow client, its recommended to install the cloud '
|
||||||
|
'tpu client directly with pip install cloud-tpu-client .')
|
||||||
|
from tensorflow.python.tpu.client import client
|
||||||
|
|
||||||
def is_running_in_gce():
|
def is_running_in_gce():
|
||||||
return True
|
return True
|
||||||
@ -44,7 +48,7 @@ DeviceDetails = collections.namedtuple(
|
|||||||
|
|
||||||
|
|
||||||
@tf_export('distribute.cluster_resolver.TPUClusterResolver')
|
@tf_export('distribute.cluster_resolver.TPUClusterResolver')
|
||||||
class TPUClusterResolver(ClusterResolver):
|
class TPUClusterResolver(cluster_resolver.ClusterResolver):
|
||||||
"""Cluster Resolver for Google Cloud TPUs.
|
"""Cluster Resolver for Google Cloud TPUs.
|
||||||
|
|
||||||
This is an implementation of cluster resolvers for the Google Cloud TPU
|
This is an implementation of cluster resolvers for the Google Cloud TPU
|
||||||
@ -135,7 +139,6 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
filled in produce an absolute URL to the discovery document for that
|
filled in produce an absolute URL to the discovery document for that
|
||||||
service. The environment variable 'TPU_API_DISCOVERY_URL' will override
|
service. The environment variable 'TPU_API_DISCOVERY_URL' will override
|
||||||
this.
|
this.
|
||||||
**kwargs: Extra keyword arguments passed to CloudTPUClient.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ImportError: If the googleapiclient is not installed.
|
ImportError: If the googleapiclient is not installed.
|
||||||
@ -144,7 +147,7 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
Google Cloud environment.
|
Google Cloud environment.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._cloud_tpu_client = CloudTPUClient(
|
self._cloud_tpu_client = client.Client(
|
||||||
tpu=tpu,
|
tpu=tpu,
|
||||||
zone=zone,
|
zone=zone,
|
||||||
project=project,
|
project=project,
|
||||||
@ -208,7 +211,7 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
if not job_tasks:
|
if not job_tasks:
|
||||||
raise ValueError('No TPUs with the specified names exist.')
|
raise ValueError('No TPUs with the specified names exist.')
|
||||||
master = job_tasks[0]
|
master = job_tasks[0]
|
||||||
return format_master_url(master, 'grpc')
|
return cluster_resolver.format_master_url(master, 'grpc')
|
||||||
|
|
||||||
def get_master(self):
|
def get_master(self):
|
||||||
return self.master()
|
return self.master()
|
||||||
@ -277,7 +280,8 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
device_details = TPUClusterResolver._get_device_dict_and_cores(
|
device_details = TPUClusterResolver._get_device_dict_and_cores(
|
||||||
get_accelerator_devices(self.master(), config_proto=config_proto))
|
cluster_resolver.get_accelerator_devices(
|
||||||
|
self.master(), config_proto=config_proto))
|
||||||
break
|
break
|
||||||
except errors.DeadlineExceededError:
|
except errors.DeadlineExceededError:
|
||||||
error_message = ('Failed to connect to master. The TPU might not be '
|
error_message = ('Failed to connect to master. The TPU might not be '
|
||||||
|
@ -25,16 +25,24 @@ from six.moves.urllib.error import URLError
|
|||||||
|
|
||||||
from tensorflow.python import framework
|
from tensorflow.python import framework
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
|
|
||||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
|
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
|
||||||
from tensorflow.python.eager.context import LogicalDevice
|
from tensorflow.python.eager.context import LogicalDevice
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
mock = test.mock
|
mock = test.mock
|
||||||
|
|
||||||
|
try:
|
||||||
|
from cloud_tpu_client import client # pylint: disable=g-import-not-at-top
|
||||||
|
except ImportError:
|
||||||
|
logging.warning(
|
||||||
|
'Falling back to tensorflow client, its recommended to install the cloud '
|
||||||
|
'tpu client directly with pip install cloud-tpu-client .')
|
||||||
|
from tensorflow.python.tpu.client import client
|
||||||
|
|
||||||
|
|
||||||
class MockRequestClass(object):
|
class MockRequestClass(object):
|
||||||
|
|
||||||
@ -141,7 +149,7 @@ class TPUClusterResolverTest(test.TestCase):
|
|||||||
def testIsRunningInGce(self):
|
def testIsRunningInGce(self):
|
||||||
self.assertTrue(resolver.is_running_in_gce())
|
self.assertTrue(resolver.is_running_in_gce())
|
||||||
|
|
||||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
@mock.patch.object(client, '_request_compute_metadata',
|
||||||
mock_request_compute_metadata)
|
mock_request_compute_metadata)
|
||||||
def testRetrieveProjectAndZoneFromMetadata(self):
|
def testRetrieveProjectAndZoneFromMetadata(self):
|
||||||
tpu_map = {
|
tpu_map = {
|
||||||
@ -174,7 +182,7 @@ class TPUClusterResolverTest(test.TestCase):
|
|||||||
self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
|
self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
|
||||||
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
|
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
|
||||||
|
|
||||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
@mock.patch.object(client, '_request_compute_metadata',
|
||||||
mock_request_compute_metadata)
|
mock_request_compute_metadata)
|
||||||
def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self):
|
def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self):
|
||||||
tpu_map = {
|
tpu_map = {
|
||||||
@ -200,7 +208,7 @@ class TPUClusterResolverTest(test.TestCase):
|
|||||||
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
|
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
|
||||||
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
|
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
|
||||||
|
|
||||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
@mock.patch.object(client, '_request_compute_metadata',
|
||||||
mock_request_compute_metadata)
|
mock_request_compute_metadata)
|
||||||
def testNotReadyCloudTpu(self):
|
def testNotReadyCloudTpu(self):
|
||||||
tpu_map = {
|
tpu_map = {
|
||||||
@ -299,7 +307,7 @@ class TPUClusterResolverTest(test.TestCase):
|
|||||||
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
|
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
|
||||||
self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master())
|
self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master())
|
||||||
|
|
||||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
@mock.patch.object(client, '_request_compute_metadata',
|
||||||
mock_request_compute_metadata)
|
mock_request_compute_metadata)
|
||||||
def testPodResolution(self):
|
def testPodResolution(self):
|
||||||
tpu_map = {
|
tpu_map = {
|
||||||
|
68
tensorflow/python/tpu/client/BUILD
Normal file
68
tensorflow/python/tpu/client/BUILD
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
# Cloud TPU Client.
|
||||||
|
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "client",
|
||||||
|
srcs = [
|
||||||
|
"client.py",
|
||||||
|
"version.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "client_lib",
|
||||||
|
srcs = [
|
||||||
|
"__init__.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":client",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "client_py_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["client_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":client",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
"//tensorflow/python:training_server_lib",
|
||||||
|
],
|
||||||
|
grpc_enabled = True,
|
||||||
|
main = "client_test.py",
|
||||||
|
python_version = "PY3",
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "client_py2_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["client_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":client",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
"//tensorflow/python:training_server_lib",
|
||||||
|
],
|
||||||
|
grpc_enabled = True,
|
||||||
|
main = "client_test.py",
|
||||||
|
python_version = "PY2",
|
||||||
|
)
|
21
tensorflow/python/tpu/client/__init__.py
Normal file
21
tensorflow/python/tpu/client/__init__.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
"""Cloud TPU Client."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.tpu.client.client import Client
|
@ -23,8 +23,6 @@ import os
|
|||||||
|
|
||||||
from six.moves.urllib import request
|
from six.moves.urllib import request
|
||||||
|
|
||||||
from tensorflow.python.util import compat
|
|
||||||
|
|
||||||
_GOOGLE_API_CLIENT_INSTALLED = True
|
_GOOGLE_API_CLIENT_INSTALLED = True
|
||||||
try:
|
try:
|
||||||
from apiclient import discovery # pylint: disable=g-import-not-at-top
|
from apiclient import discovery # pylint: disable=g-import-not-at-top
|
||||||
@ -49,23 +47,23 @@ def _request_compute_metadata(path):
|
|||||||
'%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path),
|
'%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path),
|
||||||
headers={'Metadata-Flavor': 'Google'})
|
headers={'Metadata-Flavor': 'Google'})
|
||||||
resp = request.urlopen(req)
|
resp = request.urlopen(req)
|
||||||
return compat.as_bytes(resp.read())
|
return resp.read()
|
||||||
|
|
||||||
|
|
||||||
def _environment_var_to_network_endpoints(endpoints):
|
def _environment_var_to_network_endpoints(endpoints):
|
||||||
"""Yields a dict with ip address and port."""
|
"""Yields a dict with ip address and port."""
|
||||||
for endpoint in endpoints.split(compat.as_text(',')):
|
for endpoint in endpoints.split(','):
|
||||||
grpc_prefix = compat.as_text('grpc://')
|
grpc_prefix = 'grpc://'
|
||||||
if endpoint.startswith(grpc_prefix):
|
if endpoint.startswith(grpc_prefix):
|
||||||
endpoint = endpoint.split(grpc_prefix)[1]
|
endpoint = endpoint.split(grpc_prefix)[1]
|
||||||
parts = endpoint.split(compat.as_text(':'))
|
parts = endpoint.split(':')
|
||||||
ip_address = parts[0]
|
ip_address = parts[0]
|
||||||
port = _DEFAULT_ENDPOINT_PORT
|
port = _DEFAULT_ENDPOINT_PORT
|
||||||
if len(parts) > 1:
|
if len(parts) > 1:
|
||||||
port = parts[1]
|
port = parts[1]
|
||||||
yield {
|
yield {
|
||||||
'ipAddress': compat.as_text(ip_address),
|
'ipAddress': ip_address,
|
||||||
'port': compat.as_text(port)
|
'port': port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -79,7 +77,7 @@ def _get_tpu_name(tpu):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class CloudTPUClient(object):
|
class Client(object):
|
||||||
"""Client for working with the Cloud TPU API.
|
"""Client for working with the Cloud TPU API.
|
||||||
|
|
||||||
This client is intended to be used for resolving tpu name to ip addresses.
|
This client is intended to be used for resolving tpu name to ip addresses.
|
||||||
@ -108,7 +106,7 @@ class CloudTPUClient(object):
|
|||||||
if tpu is None:
|
if tpu is None:
|
||||||
raise ValueError('Please provide a TPU Name to connect to.')
|
raise ValueError('Please provide a TPU Name to connect to.')
|
||||||
|
|
||||||
self._tpu = compat.as_text(tpu)
|
self._tpu = tpu
|
||||||
|
|
||||||
self._use_api = not tpu.startswith('grpc://')
|
self._use_api = not tpu.startswith('grpc://')
|
||||||
self._service = service
|
self._service = service
|
||||||
@ -124,12 +122,11 @@ class CloudTPUClient(object):
|
|||||||
if project:
|
if project:
|
||||||
self._project = project
|
self._project = project
|
||||||
else:
|
else:
|
||||||
self._project = compat.as_str(
|
self._project = _request_compute_metadata('project/project-id')
|
||||||
_request_compute_metadata('project/project-id'))
|
|
||||||
if zone:
|
if zone:
|
||||||
self._zone = zone
|
self._zone = zone
|
||||||
else:
|
else:
|
||||||
zone_path = compat.as_str(_request_compute_metadata('instance/zone'))
|
zone_path = _request_compute_metadata('instance/zone')
|
||||||
self._zone = zone_path.split('/')[-1]
|
self._zone = zone_path.split('/')[-1]
|
||||||
self._discovery_url = _environment_discovery_url() or discovery_url
|
self._discovery_url = _environment_discovery_url() or discovery_url
|
||||||
|
|
||||||
@ -166,7 +163,7 @@ class CloudTPUClient(object):
|
|||||||
"""Returns the TPU metadata object from the TPU Get API call."""
|
"""Returns the TPU metadata object from the TPU Get API call."""
|
||||||
try:
|
try:
|
||||||
full_name = 'projects/%s/locations/%s/nodes/%s' % (
|
full_name = 'projects/%s/locations/%s/nodes/%s' % (
|
||||||
self._project, self._zone, compat.as_text(self._tpu))
|
self._project, self._zone, self._tpu)
|
||||||
service = self._tpu_service()
|
service = self._tpu_service()
|
||||||
r = service.projects().locations().nodes().get(name=full_name)
|
r = service.projects().locations().nodes().get(name=full_name)
|
||||||
return r.execute()
|
return r.execute()
|
||||||
@ -220,7 +217,7 @@ class CloudTPUClient(object):
|
|||||||
|
|
||||||
if 'state' in response and response['state'] != 'READY':
|
if 'state' in response and response['state'] != 'READY':
|
||||||
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
|
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
|
||||||
(compat.as_text(self._tpu), response['state']))
|
(self._tpu, response['state']))
|
||||||
if 'networkEndpoints' in response:
|
if 'networkEndpoints' in response:
|
||||||
return response['networkEndpoints']
|
return response['networkEndpoints']
|
||||||
else:
|
else:
|
@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Lint as: python3
|
# Lint as: python3
|
||||||
"""Tests for cloud_tpu_client."""
|
"""Tests for cloud tpu client."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -21,8 +21,8 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.tpu.client import client
|
||||||
|
|
||||||
mock = test.mock
|
mock = test.mock
|
||||||
|
|
||||||
@ -85,19 +85,19 @@ class CloudTpuClientTest(test.TestCase):
|
|||||||
def testEnvironmentDiscoveryUrl(self):
|
def testEnvironmentDiscoveryUrl(self):
|
||||||
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
|
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
|
||||||
self.assertEqual('https://{api}.internal/{apiVersion}',
|
self.assertEqual('https://{api}.internal/{apiVersion}',
|
||||||
(cloud_tpu_client._environment_discovery_url()))
|
(client._environment_discovery_url()))
|
||||||
|
|
||||||
def testEnvironmentVarToNetworkEndpointsSingleIp(self):
|
def testEnvironmentVarToNetworkEndpointsSingleIp(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
[{'ipAddress': '1.2.3.4', 'port': '1234'}],
|
[{'ipAddress': '1.2.3.4', 'port': '1234'}],
|
||||||
list(cloud_tpu_client._environment_var_to_network_endpoints(
|
list(client._environment_var_to_network_endpoints(
|
||||||
'1.2.3.4:1234')))
|
'1.2.3.4:1234')))
|
||||||
|
|
||||||
def testEnvironmentVarToNetworkEndpointsSingleGrpcAddress(self):
|
def testEnvironmentVarToNetworkEndpointsSingleGrpcAddress(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
[{'ipAddress': '1.2.3.4', 'port': '2000'}],
|
[{'ipAddress': '1.2.3.4', 'port': '2000'}],
|
||||||
list(
|
list(
|
||||||
cloud_tpu_client._environment_var_to_network_endpoints(
|
client._environment_var_to_network_endpoints(
|
||||||
'grpc://1.2.3.4:2000')))
|
'grpc://1.2.3.4:2000')))
|
||||||
|
|
||||||
def testEnvironmentVarToNetworkEndpointsMultipleIps(self):
|
def testEnvironmentVarToNetworkEndpointsMultipleIps(self):
|
||||||
@ -105,47 +105,47 @@ class CloudTpuClientTest(test.TestCase):
|
|||||||
[{'ipAddress': '1.2.3.4', 'port': '2000'},
|
[{'ipAddress': '1.2.3.4', 'port': '2000'},
|
||||||
{'ipAddress': '5.6.7.8', 'port': '1234'}],
|
{'ipAddress': '5.6.7.8', 'port': '1234'}],
|
||||||
list(
|
list(
|
||||||
cloud_tpu_client._environment_var_to_network_endpoints(
|
client._environment_var_to_network_endpoints(
|
||||||
'1.2.3.4:2000,5.6.7.8:1234')))
|
'1.2.3.4:2000,5.6.7.8:1234')))
|
||||||
|
|
||||||
def testEnvironmentVarToNetworkEndpointsMultipleGrpcAddresses(self):
|
def testEnvironmentVarToNetworkEndpointsMultipleGrpcAddresses(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
[{'ipAddress': '1.2.3.4', 'port': '2000'},
|
[{'ipAddress': '1.2.3.4', 'port': '2000'},
|
||||||
{'ipAddress': '5.6.7.8', 'port': '1234'}],
|
{'ipAddress': '5.6.7.8', 'port': '1234'}],
|
||||||
list(cloud_tpu_client._environment_var_to_network_endpoints(
|
list(client._environment_var_to_network_endpoints(
|
||||||
'grpc://1.2.3.4:2000,grpc://5.6.7.8:1234')))
|
'grpc://1.2.3.4:2000,grpc://5.6.7.8:1234')))
|
||||||
|
|
||||||
def testEnvironmentVarToNetworkEndpointsMissingPortAndMixed(self):
|
def testEnvironmentVarToNetworkEndpointsMissingPortAndMixed(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
[{'ipAddress': '1.2.3.4', 'port': '2000'},
|
[{'ipAddress': '1.2.3.4', 'port': '2000'},
|
||||||
{'ipAddress': '5.6.7.8', 'port': '8470'}],
|
{'ipAddress': '5.6.7.8', 'port': '8470'}],
|
||||||
list(cloud_tpu_client._environment_var_to_network_endpoints(
|
list(client._environment_var_to_network_endpoints(
|
||||||
'1.2.3.4:2000,grpc://5.6.7.8')))
|
'1.2.3.4:2000,grpc://5.6.7.8')))
|
||||||
|
|
||||||
def testInitializeNoArguments(self):
|
def testInitializeNoArguments(self):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, 'Please provide a TPU Name to connect to.'):
|
ValueError, 'Please provide a TPU Name to connect to.'):
|
||||||
cloud_tpu_client.CloudTPUClient()
|
client.Client()
|
||||||
|
|
||||||
def testInitializeMultiElementTpuArray(self):
|
def testInitializeMultiElementTpuArray(self):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
NotImplementedError,
|
NotImplementedError,
|
||||||
'Using multiple TPUs in a single session is not yet implemented'):
|
'Using multiple TPUs in a single session is not yet implemented'):
|
||||||
cloud_tpu_client.CloudTPUClient(tpu=['multiple', 'elements'])
|
client.Client(tpu=['multiple', 'elements'])
|
||||||
|
|
||||||
def assertClientContains(self, client):
|
def assertClientContains(self, c):
|
||||||
self.assertEqual('tpu_name', client._tpu)
|
self.assertEqual('tpu_name', c._tpu)
|
||||||
self.assertEqual(True, client._use_api)
|
self.assertEqual(True, c._use_api)
|
||||||
self.assertEqual(None, client._credentials)
|
self.assertEqual(None, c._credentials)
|
||||||
self.assertEqual('test-project', client._project)
|
self.assertEqual('test-project', c._project)
|
||||||
self.assertEqual('us-central1-c', client._zone)
|
self.assertEqual('us-central1-c', c._zone)
|
||||||
self.assertEqual(None, client._discovery_url)
|
self.assertEqual(None, c._discovery_url)
|
||||||
self.assertEqual([{
|
self.assertEqual([{
|
||||||
'ipAddress': '10.1.2.3',
|
'ipAddress': '10.1.2.3',
|
||||||
'port': '8470'
|
'port': '8470'
|
||||||
}], client.network_endpoints())
|
}], c.network_endpoints())
|
||||||
|
|
||||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
@mock.patch.object(client, '_request_compute_metadata',
|
||||||
mock_request_compute_metadata)
|
mock_request_compute_metadata)
|
||||||
def testInitializeNoArgumentsWithEnvironmentVariable(self):
|
def testInitializeNoArgumentsWithEnvironmentVariable(self):
|
||||||
os.environ['TPU_NAME'] = 'tpu_name'
|
os.environ['TPU_NAME'] = 'tpu_name'
|
||||||
@ -156,11 +156,11 @@ class CloudTpuClientTest(test.TestCase):
|
|||||||
'health': 'HEALTHY'
|
'health': 'HEALTHY'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
client = cloud_tpu_client.CloudTPUClient(
|
c = client.Client(
|
||||||
service=self.mock_service_client(tpu_map=tpu_map))
|
service=self.mock_service_client(tpu_map=tpu_map))
|
||||||
self.assertClientContains(client)
|
self.assertClientContains(c)
|
||||||
|
|
||||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
@mock.patch.object(client, '_request_compute_metadata',
|
||||||
mock_request_compute_metadata)
|
mock_request_compute_metadata)
|
||||||
def testInitializeTpuName(self):
|
def testInitializeTpuName(self):
|
||||||
tpu_map = {
|
tpu_map = {
|
||||||
@ -170,42 +170,42 @@ class CloudTpuClientTest(test.TestCase):
|
|||||||
'health': 'HEALTHY'
|
'health': 'HEALTHY'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
client = cloud_tpu_client.CloudTPUClient(
|
c = client.Client(
|
||||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||||
self.assertClientContains(client)
|
self.assertClientContains(c)
|
||||||
|
|
||||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
@mock.patch.object(client, '_request_compute_metadata',
|
||||||
mock_request_compute_metadata)
|
mock_request_compute_metadata)
|
||||||
def testInitializeIpAddress(self):
|
def testInitializeIpAddress(self):
|
||||||
client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
|
c = client.Client(tpu='grpc://1.2.3.4:8470')
|
||||||
self.assertEqual('grpc://1.2.3.4:8470', client._tpu)
|
self.assertEqual('grpc://1.2.3.4:8470', c._tpu)
|
||||||
self.assertEqual(False, client._use_api)
|
self.assertEqual(False, c._use_api)
|
||||||
self.assertEqual(None, client._service)
|
self.assertEqual(None, c._service)
|
||||||
self.assertEqual(None, client._credentials)
|
self.assertEqual(None, c._credentials)
|
||||||
self.assertEqual(None, client._project)
|
self.assertEqual(None, c._project)
|
||||||
self.assertEqual(None, client._zone)
|
self.assertEqual(None, c._zone)
|
||||||
self.assertEqual(None, client._discovery_url)
|
self.assertEqual(None, c._discovery_url)
|
||||||
self.assertEqual([{
|
self.assertEqual([{
|
||||||
'ipAddress': '1.2.3.4',
|
'ipAddress': '1.2.3.4',
|
||||||
'port': '8470'
|
'port': '8470'
|
||||||
}], client.network_endpoints())
|
}], c.network_endpoints())
|
||||||
|
|
||||||
def testInitializeWithoutMetadata(self):
|
def testInitializeWithoutMetadata(self):
|
||||||
client = cloud_tpu_client.CloudTPUClient(
|
c = client.Client(
|
||||||
tpu='tpu_name', project='project', zone='zone')
|
tpu='tpu_name', project='project', zone='zone')
|
||||||
self.assertEqual('tpu_name', client._tpu)
|
self.assertEqual('tpu_name', c._tpu)
|
||||||
self.assertEqual(True, client._use_api)
|
self.assertEqual(True, c._use_api)
|
||||||
self.assertEqual(None, client._service)
|
self.assertEqual(None, c._service)
|
||||||
self.assertEqual(None, client._credentials)
|
self.assertEqual(None, c._credentials)
|
||||||
self.assertEqual('project', client._project)
|
self.assertEqual('project', c._project)
|
||||||
self.assertEqual('zone', client._zone)
|
self.assertEqual('zone', c._zone)
|
||||||
self.assertEqual(None, client._discovery_url)
|
self.assertEqual(None, c._discovery_url)
|
||||||
|
|
||||||
def testRecoverableNoApiAccess(self):
|
def testRecoverableNoApiAccess(self):
|
||||||
client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
|
c = client.Client(tpu='grpc://1.2.3.4:8470')
|
||||||
self.assertEqual(True, client.recoverable())
|
self.assertEqual(True, c.recoverable())
|
||||||
|
|
||||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
@mock.patch.object(client, '_request_compute_metadata',
|
||||||
mock_request_compute_metadata)
|
mock_request_compute_metadata)
|
||||||
def testRecoverableNoState(self):
|
def testRecoverableNoState(self):
|
||||||
tpu_map = {
|
tpu_map = {
|
||||||
@ -214,11 +214,11 @@ class CloudTpuClientTest(test.TestCase):
|
|||||||
'port': '8470',
|
'port': '8470',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
client = cloud_tpu_client.CloudTPUClient(
|
c = client.Client(
|
||||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||||
self.assertEqual(True, client.recoverable())
|
self.assertEqual(True, c.recoverable())
|
||||||
|
|
||||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
@mock.patch.object(client, '_request_compute_metadata',
|
||||||
mock_request_compute_metadata)
|
mock_request_compute_metadata)
|
||||||
def testRecoverableReady(self):
|
def testRecoverableReady(self):
|
||||||
tpu_map = {
|
tpu_map = {
|
||||||
@ -228,11 +228,11 @@ class CloudTpuClientTest(test.TestCase):
|
|||||||
'state': 'READY',
|
'state': 'READY',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
client = cloud_tpu_client.CloudTPUClient(
|
c = client.Client(
|
||||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||||
self.assertEqual(True, client.recoverable())
|
self.assertEqual(True, c.recoverable())
|
||||||
|
|
||||||
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
|
@mock.patch.object(client, '_request_compute_metadata',
|
||||||
mock_request_compute_metadata)
|
mock_request_compute_metadata)
|
||||||
def testRecoverablePreempted(self):
|
def testRecoverablePreempted(self):
|
||||||
tpu_map = {
|
tpu_map = {
|
||||||
@ -242,9 +242,9 @@ class CloudTpuClientTest(test.TestCase):
|
|||||||
'state': 'PREEMPTED',
|
'state': 'PREEMPTED',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
client = cloud_tpu_client.CloudTPUClient(
|
c = client.Client(
|
||||||
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
|
||||||
self.assertEqual(False, client.recoverable())
|
self.assertEqual(False, c.recoverable())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
15
tensorflow/python/tpu/client/pip_package/BUILD
Normal file
15
tensorflow/python/tpu/client/pip_package/BUILD
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# Description:
|
||||||
|
# Tools for building the Cloud TPU Client pip package.
|
||||||
|
|
||||||
|
package(default_visibility = ["//visibility:private"])
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
sh_binary(
|
||||||
|
name = "build_pip_package",
|
||||||
|
srcs = ["build_pip_package.sh"],
|
||||||
|
data = [
|
||||||
|
"setup.py",
|
||||||
|
"//tensorflow/python/tpu/client:client_lib",
|
||||||
|
],
|
||||||
|
)
|
3
tensorflow/python/tpu/client/pip_package/README
Normal file
3
tensorflow/python/tpu/client/pip_package/README
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
Client responsible for communicating the Cloud TPU API. Released seperately from tensorflow.
|
||||||
|
|
||||||
|
https://pypi.org/project/cloud-tpu-client/
|
65
tensorflow/python/tpu/client/pip_package/build_pip_package.sh
Executable file
65
tensorflow/python/tpu/client/pip_package/build_pip_package.sh
Executable file
@ -0,0 +1,65 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
if [ "$(uname)" = "Darwin" ]; then
|
||||||
|
sedi="sed -i ''"
|
||||||
|
else
|
||||||
|
sedi="sed -i"
|
||||||
|
fi
|
||||||
|
|
||||||
|
PACKAGE_NAME="cloud_tpu_client"
|
||||||
|
PIP_PACKAGE="tensorflow/python/tpu/client/pip_package"
|
||||||
|
RUNFILES="bazel-bin/tensorflow/python/tpu/client/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow/python/tpu/client"
|
||||||
|
|
||||||
|
function main() {
|
||||||
|
if [ $# -lt 1 ] ; then
|
||||||
|
echo "No destination dir provided"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
DEST=$1
|
||||||
|
TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX)
|
||||||
|
|
||||||
|
echo $(date) : "=== Using tmpdir: ${TMPDIR}"
|
||||||
|
|
||||||
|
cp ${PIP_PACKAGE}/README ${TMPDIR}
|
||||||
|
cp ${PIP_PACKAGE}/setup.py ${TMPDIR}
|
||||||
|
mkdir ${TMPDIR}/${PACKAGE_NAME}
|
||||||
|
cp -a ${RUNFILES}/. ${TMPDIR}/${PACKAGE_NAME}/
|
||||||
|
|
||||||
|
# Fix the import statements to reflect the copied over path.
|
||||||
|
find ${TMPDIR}/${PACKAGE_NAME} -name \*.py |
|
||||||
|
xargs $sedi -e '
|
||||||
|
s/^from tensorflow.python.tpu.client/from '${PACKAGE_NAME}'/
|
||||||
|
'
|
||||||
|
echo $(ls $TMPDIR)
|
||||||
|
|
||||||
|
pushd ${TMPDIR}
|
||||||
|
echo $(date) : "=== Building wheel"
|
||||||
|
echo $(pwd)
|
||||||
|
python setup.py bdist_wheel >/dev/null
|
||||||
|
python3 setup.py bdist_wheel >/dev/null
|
||||||
|
mkdir -p ${DEST}
|
||||||
|
cp dist/* ${DEST}
|
||||||
|
popd
|
||||||
|
rm -rf ${TMPDIR}
|
||||||
|
echo $(date) : "=== Output wheel file is in: ${DEST}"
|
||||||
|
}
|
||||||
|
|
||||||
|
main "$@"
|
56
tensorflow/python/tpu/client/pip_package/setup.py
Normal file
56
tensorflow/python/tpu/client/pip_package/setup.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
"""Cloud TPU Client package."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from cloud_tpu_client.version import __version__
|
||||||
|
from setuptools import find_packages
|
||||||
|
from setuptools import setup
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='cloud-tpu-client',
|
||||||
|
version=__version__.replace('-', ''),
|
||||||
|
description='Client for using Cloud TPUs',
|
||||||
|
long_description='Client for using Cloud TPUs',
|
||||||
|
url='https://cloud.google.com/tpu/',
|
||||||
|
author='Google Inc.',
|
||||||
|
author_email='packages@tensorflow.org',
|
||||||
|
packages=find_packages(),
|
||||||
|
classifiers=[
|
||||||
|
'Development Status :: 5 - Production/Stable',
|
||||||
|
'Intended Audience :: Developers',
|
||||||
|
'Intended Audience :: Education',
|
||||||
|
'Intended Audience :: Science/Research',
|
||||||
|
'License :: OSI Approved :: Apache Software License',
|
||||||
|
'Programming Language :: Python :: 2',
|
||||||
|
'Programming Language :: Python :: 2.7',
|
||||||
|
'Programming Language :: Python :: 3',
|
||||||
|
'Programming Language :: Python :: 3.4',
|
||||||
|
'Programming Language :: Python :: 3.5',
|
||||||
|
'Programming Language :: Python :: 3.6',
|
||||||
|
'Topic :: Scientific/Engineering',
|
||||||
|
'Topic :: Scientific/Engineering :: Mathematics',
|
||||||
|
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||||
|
'Topic :: Software Development',
|
||||||
|
'Topic :: Software Development :: Libraries',
|
||||||
|
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||||
|
],
|
||||||
|
license='Apache 2.0',
|
||||||
|
keywords='tensorflow tpu',
|
||||||
|
install_requires=['google-api-python-client', 'oauth2client']
|
||||||
|
)
|
21
tensorflow/python/tpu/client/version.py
Normal file
21
tensorflow/python/tpu/client/version.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
"""Cloud TPU Client version information."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
__version__ = "0.2"
|
@ -21,10 +21,10 @@
|
|||||||
function install_ctpu {
|
function install_ctpu {
|
||||||
PIP_CMD="${1:-pip}"
|
PIP_CMD="${1:-pip}"
|
||||||
|
|
||||||
# TPUClusterResolver has a runtime dependency on these Python libraries when
|
# TPUClusterResolver has a runtime dependency cloud-tpu-client when
|
||||||
# resolving a Cloud TPU. It's very likely we want these installed if we're
|
# resolving a Cloud TPU. It's very likely we want this installed if we're
|
||||||
# using CTPU.
|
# using CTPU.
|
||||||
"${PIP_CMD}" install --user --upgrade google-api-python-client oauth2client
|
"${PIP_CMD}" install --user --upgrade cloud-tpu-client
|
||||||
|
|
||||||
wget -nv "https://dl.google.com/cloud_tpu/ctpu/latest/linux/ctpu"
|
wget -nv "https://dl.google.com/cloud_tpu/ctpu/latest/linux/ctpu"
|
||||||
chmod a+x ctpu
|
chmod a+x ctpu
|
||||||
|
Loading…
x
Reference in New Issue
Block a user