Adds preliminary support for Cloud TPUs with Cluster Resolvers. This aims to allow users to have a better experience when specifying one or multiple Cloud TPUs for their training jobs by allowing users to use names rather than IP addresses.
PiperOrigin-RevId: 163393443
This commit is contained in:
parent
e5353c941c
commit
28373cfe70
@ -28,6 +28,7 @@ py_library(
|
||||
deps = [
|
||||
":cluster_resolver_py",
|
||||
":gce_cluster_resolver_py",
|
||||
":tpu_cluster_resolver_py",
|
||||
],
|
||||
)
|
||||
|
||||
@ -54,6 +55,18 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu_cluster_resolver_py",
|
||||
srcs = [
|
||||
"python/training/tpu_cluster_resolver.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":cluster_resolver_py",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "cluster_resolver_py_test",
|
||||
size = "small",
|
||||
@ -81,3 +94,17 @@ tf_py_test(
|
||||
],
|
||||
main = "python/training/gce_cluster_resolver_test.py",
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tpu_cluster_resolver_py_test",
|
||||
size = "small",
|
||||
srcs = ["python/training/tpu_cluster_resolver_test.py"],
|
||||
additional_deps = [
|
||||
":tpu_cluster_resolver_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
main = "python/training/tpu_cluster_resolver_test.py",
|
||||
)
|
||||
|
@ -22,3 +22,4 @@ from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
|
||||
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
|
||||
|
@ -0,0 +1,105 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of Cluster Resolvers for Cloud TPUs."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
|
||||
from tensorflow.python.training.server_lib import ClusterSpec
|
||||
|
||||
_GOOGLE_API_CLIENT_INSTALLED = True
|
||||
try:
|
||||
from googleapiclient import discovery # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
_GOOGLE_API_CLIENT_INSTALLED = False
|
||||
|
||||
|
||||
class TPUClusterResolver(ClusterResolver):
  """Cluster Resolver for Google Cloud TPUs.

  This is an implementation of cluster resolvers for the Google Cloud TPU
  service. As Cloud TPUs are in alpha, you will need to specify an API
  definition file for this to consume, in addition to a list of Cloud TPUs in
  your Google Cloud Platform project.
  """

  def __init__(self,
               api_definition,
               project,
               zone,
               tpu_names,
               credentials,
               job_name='tpu_worker',
               service=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU APIs
    for the IP addresses and ports of each Cloud TPU listed.

    Args:
      api_definition: (Alpha only) A copy of the JSON API definitions for
        Cloud TPUs. This will be removed once Cloud TPU enters beta. Only used
        when `service` is not supplied.
      project: Name of the GCP project containing Cloud TPUs
      zone: Zone where the TPUs are located
      tpu_names: A list of names of the target Cloud TPUs.
      credentials: GCE Credentials. Only used when `service` is not supplied.
      job_name: Name of the TensorFlow job the TPUs belong to.
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the
        api_definition and credentials parameters will be ignored.

    Raises:
      ImportError: If no `service` is supplied and the googleapiclient is not
        installed.
    """
    self._project = project
    self._zone = zone
    self._tpu_names = tpu_names
    self._job_name = job_name

    if service is None:
      # Building our own client is the only path that needs googleapiclient;
      # callers that inject a service object can run without it installed.
      if not _GOOGLE_API_CLIENT_INSTALLED:
        raise ImportError('googleapiclient must be installed before using the '
                          'TPU cluster resolver')

      # TODO(frankchn): Remove once Cloud TPU API Definitions are public and
      # replace with discovery.build('tpu', 'v1')
      self._service = discovery.build_from_document(api_definition,
                                                    credentials=credentials)
    else:
      self._service = service

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called, issuing one `nodes().get()` request per configured TPU name.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs. The
      single job is named after `job_name` and its tasks follow the order of
      `tpu_names`.
    """
    worker_list = []

    for tpu_name in self._tpu_names:
      full_name = 'projects/%s/locations/%s/nodes/%s' % (
          self._project, self._zone, tpu_name)
      request = self._service.projects().locations().nodes().get(name=full_name)
      response = request.execute()

      # execute() yields the decoded JSON response body (a dict), so the
      # fields have to be read by key; attribute access would raise
      # AttributeError.
      instance_url = '%s:%s' % (response['ipAddress'], response['port'])
      worker_list.append(instance_url)

    return ClusterSpec({self._job_name: worker_list})
|
@ -0,0 +1,111 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for TPUClusterResolver."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
|
||||
mock = test.mock
|
||||
|
||||
|
||||
class TPUClusterResolverTest(test.TestCase):
  """Checks the ClusterSpecs that TPUClusterResolver builds from API replies."""

  def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
    """Verifies that the ClusterSpec generates the correct proto.

    We are testing this four different ways to ensure that the ClusterSpec
    returned by the TPUClusterResolver behaves identically to a normal
    ClusterSpec when passed into the generic ClusterSpec libraries.

    Args:
      cluster_spec: ClusterSpec returned by the TPUClusterResolver
      expected_proto: Expected protobuf
    """
    self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
    self.assertProtoEquals(
        expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())

  def mock_service_client(
      self,
      tpu_map=None):
    """Builds a stand-in for the Cloud TPU API service object.

    The resolver under test calls
    `service.projects().locations().nodes().get(name=...).execute()`, so the
    fake has to be wired through the chain of call *return values*, and `get`
    has to hand back a request object whose `execute()` produces the response.

    Args:
      tpu_map: Dict mapping a fully qualified node name to the response dict
        that `execute()` should return for it. Defaults to an empty dict.

    Returns:
      A MagicMock usable as the `service` argument of TPUClusterResolver.
    """
    if tpu_map is None:
      tpu_map = {}

    def get_side_effect(name):
      # Unknown names raise KeyError here, at `get()` time.
      request = mock.MagicMock()
      request.execute.return_value = tpu_map[name]
      return request

    mock_client = mock.MagicMock()
    # Each `()` in the resolver's call chain resolves to `.return_value` on
    # the mock; setting attributes directly on `projects.locations...` would
    # configure a different object that the resolver never touches.
    locations = mock_client.projects.return_value.locations.return_value
    locations.nodes.return_value.get.side_effect = get_side_effect
    return mock_client

  def testSimpleSuccessfulRetrieval(self):
    """A single TPU resolves to a one-task 'tpu_worker' job."""
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'ipAddress': '10.1.2.3',
            'port': '8470'
        }
    }

    # api_definition is a required argument of the constructor; it is unused
    # here because a service object is injected.
    tpu_cluster_resolver = TPUClusterResolver(
        api_definition=None,
        project='test-project',
        zone='us-central1-c',
        tpu_names=['test-tpu-1'],
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_map))

    actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testMultipleSuccessfulRetrieval(self):
    """Several TPUs become tasks ordered as listed in tpu_names."""
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'ipAddress': '10.1.2.3',
            'port': '8470'
        },
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
            'ipAddress': '10.4.5.6',
            'port': '8470'
        }
    }

    # api_definition is a required argument of the constructor; it is unused
    # here because a service object is injected.
    tpu_cluster_resolver = TPUClusterResolver(
        api_definition=None,
        project='test-project',
        zone='us-central1-c',
        tpu_names=['test-tpu-2', 'test-tpu-1'],
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_map))

    actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'tpu_worker' tasks { key: 0 value: '10.4.5.6:8470' }
                             tasks { key: 1 value: '10.1.2.3:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
|
Loading…
Reference in New Issue
Block a user