Refactor Tpu Cluster Resolver.

PiperOrigin-RevId: 282612898
Change-Id: I826f45bd86f8d986631efe4c36cb87321993087a
Michael Banfield 2019-11-26 12:04:11 -08:00 committed by TensorFlower Gardener
parent 602e65243d
commit 7d801fe575
6 changed files with 668 additions and 352 deletions

View File

@@ -60,11 +60,56 @@ py_library(
],
)
py_library(
name = "cloud_tpu_client",
srcs = ["cloud_tpu_client.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
"@six_archive//:six",
],
)
tf_py_test(
name = "cloud_tpu_client_py_test",
size = "small",
srcs = ["cloud_tpu_client_test.py"],
additional_deps = [
":cloud_tpu_client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "cloud_tpu_client_test.py",
python_version = "PY3",
)
tf_py_test(
name = "cloud_tpu_client_py2_test",
size = "small",
srcs = ["cloud_tpu_client_test.py"],
additional_deps = [
":cloud_tpu_client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:training_server_lib",
],
grpc_enabled = True,
main = "cloud_tpu_client_test.py",
python_version = "PY2",
)
py_library(
name = "tpu_cluster_resolver_py",
srcs = ["tpu_cluster_resolver.py"],
srcs_version = "PY2AND3",
deps = [
":cloud_tpu_client",
":base_cluster_resolver_py",
"//tensorflow/python:training_server_lib",
] + tf_additional_rpc_deps(),

View File

@@ -0,0 +1,227 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Cloud TPU Client."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from six.moves.urllib import request
from tensorflow.python.util import compat
_GOOGLE_API_CLIENT_INSTALLED = True
try:
from apiclient import discovery # pylint: disable=g-import-not-at-top
from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top
except ImportError:
_GOOGLE_API_CLIENT_INSTALLED = False
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal'
_DEFAULT_ENDPOINT_PORT = '8470'
def _environment_discovery_url():
return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
def _request_compute_metadata(path):
req = request.Request(
'%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path),
headers={'Metadata-Flavor': 'Google'})
resp = request.urlopen(req)
return compat.as_bytes(resp.read())
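# Example (illustrative): _request_compute_metadata('project/project-id') issues
# a GET to http://metadata.google.internal/computeMetadata/v1/project/project-id
# with the 'Metadata-Flavor: Google' header and returns the response body as bytes.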
def _environment_var_to_network_endpoints(endpoints):
"""Yields a dict with ip address and port."""
for endpoint in endpoints.split(compat.as_text(',')):
grpc_prefix = compat.as_text('grpc://')
if endpoint.startswith(grpc_prefix):
endpoint = endpoint.split(grpc_prefix)[1]
parts = endpoint.split(compat.as_text(':'))
ip_address = parts[0]
port = _DEFAULT_ENDPOINT_PORT
if len(parts) > 1:
port = parts[1]
yield {
'ipAddress': compat.as_text(ip_address),
'port': compat.as_text(port)
}
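# Example (illustrative): _environment_var_to_network_endpoints('grpc://1.2.3.4:2000,5.6.7.8')
# yields {'ipAddress': '1.2.3.4', 'port': '2000'} and then
# {'ipAddress': '5.6.7.8', 'port': '8470'}, filling in the default port 8470.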
def _get_tpu_name(tpu):
if tpu:
return tpu
for e in [_GKE_ENV_VARIABLE, _DEFAULT_ENV_VARIABLE]:
if e in os.environ:
return os.environ[e]
return None
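# Resolution order (for reference): an explicit `tpu` argument wins; otherwise the
# KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS (GKE) and then TPU_NAME environment variables
# are consulted; if none is set, None is returned.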
class CloudTPUClient(object):
"""Client for working with the Cloud TPU API.
This client is intended to be used for resolving a TPU name to IP addresses.
It's recommended to use this client as a context manager to utilize all of its
functionality.
"""
def __init__(self,
tpu=None,
zone=None,
project=None,
credentials='default',
service=None,
discovery_url=None):
if isinstance(tpu, list):
if not tpu:
raise ValueError('At least one TPU must be specified.')
if len(tpu) != 1:
raise NotImplementedError(
'Using multiple TPUs in a single session is not yet implemented')
tpu = tpu[0]
tpu = _get_tpu_name(tpu)
if tpu is None:
raise ValueError('Please provide a TPU Name to connect to.')
self._tpu = compat.as_text(tpu)
self._use_api = not tpu.startswith('grpc://')
self._service = service
self._credentials = None
self._project = None
self._zone = None
self._discovery_url = None
if self._use_api:
if credentials != 'default':
self._credentials = credentials
# Automatically detect project and zone if unspecified.
if project:
self._project = project
else:
self._project = compat.as_str(
_request_compute_metadata('project/project-id'))
if zone:
self._zone = zone
else:
zone_path = compat.as_str(_request_compute_metadata('instance/zone'))
self._zone = zone_path.split('/')[-1]
self._discovery_url = _environment_discovery_url() or discovery_url
def _tpu_service(self):
"""Creates a new Cloud TPU API object.
This works around an issue where the underlying HTTP connection sometimes
times out when the script has been running for too long. Other methods in
this object call this method to get a new API object whenever they need
to communicate with the Cloud API.
Returns:
A Google Cloud TPU API object.
"""
if self._service:
return self._service
credentials = self._credentials
if credentials is None or credentials == 'default':
credentials = GoogleCredentials.get_application_default()
if self._discovery_url:
return discovery.build(
'tpu',
'v1',
credentials=credentials,
discoveryServiceUrl=self._discovery_url,
cache_discovery=False)
else:
return discovery.build(
'tpu', 'v1', credentials=credentials, cache_discovery=False)
def _fetch_cloud_tpu_metadata(self):
"""Returns the TPU metadata object from the TPU Get API call."""
try:
full_name = 'projects/%s/locations/%s/nodes/%s' % (
self._project, self._zone, compat.as_text(self._tpu))
service = self._tpu_service()
r = service.projects().locations().nodes().get(name=full_name)
return r.execute()
except Exception as e:
raise ValueError("Could not lookup TPU metadata from name '%s'. Please "
'doublecheck the tpu argument in the TPUClusterResolver '
'constructor. Exception: %s' % (self._tpu, e))
def __enter__(self):
self._open = True
def __exit__(self, type, value, traceback): # pylint: disable=redefined-builtin
del type, value, traceback
def recoverable(self):
"""Returns true if the TPU is in a state where training should eventually resume.
If false, the TPU is in an unrecoverable state and should be recreated.
"""
state = self.state()
if state and state in ['TERMINATED', 'PREEMPTED']:
return False
return True
def state(self):
"""Return state of the TPU."""
if self._use_api:
metadata = self._fetch_cloud_tpu_metadata()
if 'state' in metadata:
return metadata['state']
return None
def api_available(self):
"""Return if the Cloud TPU API is available, if not certain features will not work."""
return self._use_api
def name(self):
"""Return the name of the tpu, or the ip address if name is not provided."""
return self._tpu
def get_local_ip(self):
"""Return the local ip address of the Google Cloud VM the workload is running on."""
return _request_compute_metadata('instance/network-interfaces/0/ip')
def network_endpoints(self):
"""Return a list of tpu endpoints."""
if not self._use_api:
return list(_environment_var_to_network_endpoints(self._tpu))
response = self._fetch_cloud_tpu_metadata() # pylint: disable=protected-access
if 'state' in response and response['state'] != 'READY':
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
(compat.as_text(self._tpu), response['state']))
if 'networkEndpoints' in response:
return response['networkEndpoints']
else:
return [{'ipAddress': response['ipAddress'], 'port': response['port']}]
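A minimal usage sketch of the new client (illustrative; 'my-tpu' is a placeholder TPU name, and the googleapiclient/oauth2client packages are assumed to be installed so the Cloud TPU API can be reached):

from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client

client = cloud_tpu_client.CloudTPUClient(tpu='my-tpu')
if client.api_available() and client.state() == 'READY':
  for endpoint in client.network_endpoints():
    print('%s:%s' % (endpoint['ipAddress'], endpoint['port']))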

View File

@@ -0,0 +1,251 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for cloud_tpu_client."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
from tensorflow.python.platform import test
mock = test.mock
def mock_request_compute_metadata(path):
if path == 'project/project-id':
return 'test-project'
elif path == 'instance/zone':
return 'projects/test-project/locations/us-central1-c'
elif path == 'instance/network-interfaces/0/ip':
return '10.128.1.2'
return ''
class MockRequestClass(object):
def __init__(self, name, tpu_map):
self._name = name
self._tpu_map = tpu_map
def execute(self):
if self._name in self._tpu_map:
return self._tpu_map[self._name]
else:
raise KeyError('Resource %s was not found' % self._name)
class MockNodeClass(object):
def __init__(self, tpu_map):
self._tpu_map = tpu_map
def get(self, name):
return MockRequestClass(name, self._tpu_map)
class CloudTpuClientTest(test.TestCase):
def setUp(self):
super(CloudTpuClientTest, self).setUp()
if 'TPU_API_DISCOVERY_URL' in os.environ:
del os.environ['TPU_API_DISCOVERY_URL']
if 'TPU_NAME' in os.environ:
del os.environ['TPU_NAME']
def mock_service_client(self, tpu_map=None):
if tpu_map is None:
tpu_map = {}
mock_locations = mock.MagicMock()
mock_locations.nodes.return_value = MockNodeClass(tpu_map)
mock_project = mock.MagicMock()
mock_project.locations.return_value = mock_locations
mock_client = mock.MagicMock()
mock_client.projects.return_value = mock_project
return mock_client
def testEnvironmentDiscoveryUrl(self):
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
self.assertEqual('https://{api}.internal/{apiVersion}',
(cloud_tpu_client._environment_discovery_url()))
def testEnvironmentVarToNetworkEndpointsSingleIp(self):
self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '1234'}],
list(cloud_tpu_client._environment_var_to_network_endpoints(
'1.2.3.4:1234')))
def testEnvironmentVarToNetworkEndpointsSingleGrpcAddress(self):
self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '2000'}],
list(
cloud_tpu_client._environment_var_to_network_endpoints(
'grpc://1.2.3.4:2000')))
def testEnvironmentVarToNetworkEndpointsMultipleIps(self):
self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '2000'},
{'ipAddress': '5.6.7.8', 'port': '1234'}],
list(
cloud_tpu_client._environment_var_to_network_endpoints(
'1.2.3.4:2000,5.6.7.8:1234')))
def testEnvironmentVarToNetworkEndpointsMultipleGrpcAddresses(self):
self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '2000'},
{'ipAddress': '5.6.7.8', 'port': '1234'}],
list(cloud_tpu_client._environment_var_to_network_endpoints(
'grpc://1.2.3.4:2000,grpc://5.6.7.8:1234')))
def testEnvironmentVarToNetworkEndpointsMissingPortAndMixed(self):
self.assertEqual(
[{'ipAddress': '1.2.3.4', 'port': '2000'},
{'ipAddress': '5.6.7.8', 'port': '8470'}],
list(cloud_tpu_client._environment_var_to_network_endpoints(
'1.2.3.4:2000,grpc://5.6.7.8')))
def testInitializeNoArguments(self):
with self.assertRaisesRegex(
ValueError, 'Please provide a TPU Name to connect to.'):
cloud_tpu_client.CloudTPUClient()
def testInitializeMultiElementTpuArray(self):
with self.assertRaisesRegex(
NotImplementedError,
'Using multiple TPUs in a single session is not yet implemented'):
cloud_tpu_client.CloudTPUClient(tpu=['multiple', 'elements'])
def assertClientContains(self, client):
self.assertEqual('tpu_name', client._tpu)
self.assertEqual(True, client._use_api)
self.assertEqual(None, client._credentials)
self.assertEqual('test-project', client._project)
self.assertEqual('us-central1-c', client._zone)
self.assertEqual(None, client._discovery_url)
self.assertEqual([{
'ipAddress': '10.1.2.3',
'port': '8470'
}], client.network_endpoints())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
mock_request_compute_metadata)
def testInitializeNoArgumentsWithEnvironmentVariable(self):
os.environ['TPU_NAME'] = 'tpu_name'
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
'ipAddress': '10.1.2.3',
'port': '8470',
'health': 'HEALTHY'
}
}
client = cloud_tpu_client.CloudTPUClient(
service=self.mock_service_client(tpu_map=tpu_map))
self.assertClientContains(client)
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
mock_request_compute_metadata)
def testInitializeTpuName(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
'ipAddress': '10.1.2.3',
'port': '8470',
'health': 'HEALTHY'
}
}
client = cloud_tpu_client.CloudTPUClient(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertClientContains(client)
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
mock_request_compute_metadata)
def testInitializeIpAddress(self):
client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
self.assertEqual('grpc://1.2.3.4:8470', client._tpu)
self.assertEqual(False, client._use_api)
self.assertEqual(None, client._service)
self.assertEqual(None, client._credentials)
self.assertEqual(None, client._project)
self.assertEqual(None, client._zone)
self.assertEqual(None, client._discovery_url)
self.assertEqual([{
'ipAddress': '1.2.3.4',
'port': '8470'
}], client.network_endpoints())
def testInitializeWithoutMetadata(self):
client = cloud_tpu_client.CloudTPUClient(
tpu='tpu_name', project='project', zone='zone')
self.assertEqual('tpu_name', client._tpu)
self.assertEqual(True, client._use_api)
self.assertEqual(None, client._service)
self.assertEqual(None, client._credentials)
self.assertEqual('project', client._project)
self.assertEqual('zone', client._zone)
self.assertEqual(None, client._discovery_url)
def testRecoverableNoApiAccess(self):
client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
self.assertEqual(True, client.recoverable())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
mock_request_compute_metadata)
def testRecoverableNoState(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
'ipAddress': '10.1.2.3',
'port': '8470',
}
}
client = cloud_tpu_client.CloudTPUClient(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(True, client.recoverable())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
mock_request_compute_metadata)
def testRecoverableReady(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
'ipAddress': '10.1.2.3',
'port': '8470',
'state': 'READY',
}
}
client = cloud_tpu_client.CloudTPUClient(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(True, client.recoverable())
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
mock_request_compute_metadata)
def testRecoverablePreempted(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
'ipAddress': '10.1.2.3',
'port': '8470',
'state': 'PREEMPTED',
}
}
client = cloud_tpu_client.CloudTPUClient(
tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(False, client.recoverable())
if __name__ == '__main__':
test.main()

View File

@@ -19,61 +19,30 @@ from __future__ import division
from __future__ import print_function
import collections
import os
import re
from six.moves import urllib
from six.moves.urllib.error import URLError
from six.moves.urllib.request import Request
from six.moves.urllib.request import urlopen
from tensorflow.python.distribute.cluster_resolver.cloud_tpu_client import CloudTPUClient
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import get_accelerator_devices
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
_GOOGLE_API_CLIENT_INSTALLED = True
try:
from googleapiclient import discovery # pylint: disable=g-import-not-at-top
from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top
except ImportError:
_GOOGLE_API_CLIENT_INSTALLED = False
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
def is_running_in_gce():
return True
_TPU_DEVICE_REGEX = re.compile(
r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
_TPU_CONN_RETRIES = 120
_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal'
DeviceDetails = collections.namedtuple(
'DeviceDetails', ['device_map', 'total_cores'])
def is_running_in_gce():
"""Checks for GCE presence by attempting to query the metadata service."""
try:
req = Request(
'%s/computeMetadata/v1' % _GCE_METADATA_ENDPOINT,
headers={'Metadata-Flavor': 'Google'})
resp = urllib.request.urlopen(req, timeout=1)
info = resp.info()
if 'Metadata-Flavor' in info and info['Metadata-Flavor'] == 'Google':
return True
except URLError:
pass
return False
@tf_export('distribute.cluster_resolver.TPUClusterResolver')
class TPUClusterResolver(ClusterResolver):
"""Cluster Resolver for Google Cloud TPUs.
@@ -89,53 +58,6 @@ class TPUClusterResolver(ClusterResolver):
Google internal
"""
def _tpu_service(self):
"""Creates a new Cloud TPU API object.
This works around an issue where the underlying HTTP connection sometimes
times out when the script has been running for too long. Other methods in
this object call this method to get a new API object whenever they need
to communicate with the Cloud API.
Returns:
A Google Cloud TPU API object.
"""
if self._service:
return self._service
credentials = self._credentials
if credentials is None or credentials == 'default':
credentials = GoogleCredentials.get_application_default()
if self._discovery_url:
return discovery.build(
'tpu',
'v1',
credentials=credentials,
discoveryServiceUrl=self._discovery_url,
cache_discovery=False)
else:
return discovery.build(
'tpu', 'v1', credentials=credentials, cache_discovery=False)
def _request_compute_metadata(self, path):
req = Request('%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path),
headers={'Metadata-Flavor': 'Google'})
resp = urlopen(req)
return compat.as_bytes(resp.read())
def _is_local_tpu(self):
return (
self._tpu == compat.as_bytes('') or
self._tpu == compat.as_bytes('local'))
def _should_resolve(self):
if isinstance(self._should_resolve_override, bool):
return self._should_resolve_override
else:
return not (self._tpu.startswith(compat.as_bytes('grpc://')) or
self._is_local_tpu())
@staticmethod
def _get_device_dict_and_cores(devices):
"""Returns a dict of hosts to cores and total cores given devices names.
@@ -168,25 +90,6 @@ class TPUClusterResolver(ClusterResolver):
'should never happen. Devices: {}'.format(device_dict))
return num_cores_per_host_set.pop()
@staticmethod
def _in_gke():
"""When running in GKE, the environment variable will be set."""
return _GKE_ENV_VARIABLE in os.environ
@staticmethod
def _gke_endpoints():
return os.environ[_GKE_ENV_VARIABLE]
@staticmethod
def _env_var_fallback():
if _DEFAULT_ENV_VARIABLE in os.environ:
return os.environ[_DEFAULT_ENV_VARIABLE]
return None
@staticmethod
def _environment_discovery_url():
return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
def __init__(self,
tpu=None,
zone=None,
@@ -204,11 +107,11 @@ class TPUClusterResolver(ClusterResolver):
Args:
tpu: A string corresponding to the TPU to use. If the string is an empty
string, the string 'local', or a string that begins with 'grpc://',
then it is assumed to not correspond with a Cloud TPU and will
instead be passed as the session master and no ClusterSpec propagation
will be done. In the future, this may also support a list of strings
when multiple Cloud TPUs are used.
string, the string 'local', or a string that begins with 'grpc://', then
it is assumed to not correspond with a Cloud TPU and will instead be
passed as the session master and no ClusterSpec propagation will be
done. In the future, this may also support a list of strings when
multiple Cloud TPUs are used.
zone: Zone where the TPUs are located. If omitted or empty, we will assume
that the zone of the TPU is the same as the zone of the GCE VM, which we
will try to discover from the GCE metadata service.
@@ -232,6 +135,7 @@ class TPUClusterResolver(ClusterResolver):
filled in produce an absolute URL to the discovery document for that
service. The environment variable 'TPU_API_DISCOVERY_URL' will override
this.
**kwargs: Extra keyword arguments passed to CloudTPUClient.
Raises:
ImportError: If the googleapiclient is not installed.
@@ -239,92 +143,32 @@ class TPUClusterResolver(ClusterResolver):
RuntimeError: If an empty TPU name is specified and this is running in a
Google Cloud environment.
"""
if isinstance(tpu, list):
if not tpu:
raise ValueError('At least one TPU must be specified.')
if len(tpu) != 1:
raise NotImplementedError(
'Using multiple TPUs in a single session is not yet implemented')
tpu = tpu[0]
in_gke = self._in_gke()
# When using GKE with Cloud TPUs, the env variable will be set.
if tpu is None:
if in_gke:
tpu = self._gke_endpoints()
else:
tpu = self._env_var_fallback()
if tpu is None:
raise ValueError('Please provide a TPU Name to connect to.')
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
# If we are running in Cloud and don't specify a TPU name
if is_running_in_gce() and not self._tpu:
raise RuntimeError('You need to specify a TPU Name if you are running in '
'the Google Cloud environment.')
self._cloud_tpu_client = CloudTPUClient(
tpu=tpu,
zone=zone,
project=project,
credentials=credentials,
service=service,
discovery_url=discovery_url)
self._tpu = self._cloud_tpu_client.name()
# By default the task_type is 'worker' and the task_id is 0 (which is the
# first worker in the task).
self.task_type = job_name
self.task_id = 0
if self._is_local_tpu():
self.rpc_layer = None
else:
self._environment = ''
self.rpc_layer = 'grpc'
# Setting this overrides the return value of self._should_resolve()
self._should_resolve_override = None
# We strip out the protocol if it is included, and override the
# shouldResolve function to never resolve. We are adding the protocol back
# in later in self.master().
if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'):
tpu = tpu[len(self.rpc_layer + '://'):]
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
self._should_resolve_override = False
# Whether we should actually attempt to contact Cloud APIs
should_resolve = self._should_resolve()
# We error out if we are in a non-Cloud environment which cannot talk to the
# Cloud APIs using the standard class and a special object is not passed in.
self._service = service
if (self._service is None and should_resolve and
not _GOOGLE_API_CLIENT_INSTALLED):
raise ImportError('googleapiclient and oauth2client must be installed '
'before using the TPU cluster resolver. Execute: '
'`pip install --upgrade google-api-python-client '
'oauth2client` to install with pip.')
# We save user-passed credentials, unless the user didn't pass in anything.
self._credentials = credentials
if (credentials == 'default' and should_resolve and
_GOOGLE_API_CLIENT_INSTALLED):
self._credentials = None
# Automatically detect project and zone if unspecified.
if not project and should_resolve:
project = compat.as_str(
self._request_compute_metadata('project/project-id'))
if not zone and should_resolve:
zone_path = compat.as_str(self._request_compute_metadata('instance/zone'))
zone = zone_path.split('/')[-1]
self._project = project
self._zone = zone
self._discovery_url = self._environment_discovery_url() or discovery_url
self._coordinator_name = coordinator_name
if (coordinator_name and not coordinator_address and
(should_resolve or in_gke)):
if (coordinator_name and not coordinator_address):
self._start_local_server()
else:
self._coordinator_address = coordinator_address
def __enter__(self):
self._cloud_tpu_client.__enter__()
def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
self._cloud_tpu_client.__exit__(type, value, traceback)
def master(self, task_type=None, task_id=None, rpc_layer=None):
"""Get the Master string to be used for the session.
@@ -350,35 +194,27 @@ class TPUClusterResolver(ClusterResolver):
Raises:
ValueError: If none of the TPUs specified exists.
"""
if self._should_resolve():
# We are going to communicate with the Cloud TPU APIs to get a Cluster.
cluster_spec = self.cluster_spec()
if task_type is not None and task_id is not None:
# task_type and task_id is from the function parameter
master = cluster_spec.task_address(task_type, task_id)
elif self.task_type is not None and self.task_id is not None:
# task_type and task_id is from the object
master = cluster_spec.task_address(self.task_type, self.task_id)
else:
# by default we take the first item in the cluster with the right name
job_tasks = cluster_spec.job_tasks(self.task_type)
if not job_tasks:
raise ValueError('No TPUs with the specified names exist.')
master = job_tasks[0]
cluster_spec = self.cluster_spec()
if task_type is not None and task_id is not None:
# task_type and task_id is from the function parameter
master = cluster_spec.task_address(task_type, task_id)
elif self.task_type is not None and self.task_id is not None:
# task_type and task_id is from the object
master = cluster_spec.task_address(self.task_type, self.task_id)
else:
if isinstance(self._tpu, (bytes, bytearray)):
master = compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR)[0]
else:
master = self._tpu.split(_ENDPOINTS_SEPARATOR)[0]
return format_master_url(master, rpc_layer or self.rpc_layer)
# by default we take the first item in the cluster with the right name
job_tasks = cluster_spec.job_tasks(self.task_type)
if not job_tasks:
raise ValueError('No TPUs with the specified names exist.')
master = job_tasks[0]
return format_master_url(master, 'grpc')
def get_master(self):
return self.master()
def get_job_name(self):
if ops.executing_eagerly_outside_functions() or self._should_resolve(
) or is_running_in_gce():
return self.task_type
return self.task_type
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
@@ -402,65 +238,20 @@ class TPUClusterResolver(ClusterResolver):
# tasks and
# a. Create a ClusterSpec with the coordinator
# b. Create a ClusterSpec without the coordinator
# 3. [Other (legacy non-gRPC).] We should return None.
############################################################################
if self._should_resolve():
# Case 1.
response = self._fetch_cloud_tpu_metadata() # pylint: disable=protected-access
if 'state' in response and response['state'] != 'READY':
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
(compat.as_text(self._tpu), response['state']))
if 'networkEndpoints' in response:
worker_list = [
'%s:%s' % (endpoint['ipAddress'], endpoint['port'])
for endpoint in response['networkEndpoints']
]
else:
# Fall back to the deprecated response format
instance_url = '%s:%s' % (response['ipAddress'], response['port'])
worker_list = [instance_url]
cluster_spec = {self.task_type: worker_list}
else:
is_eager = ops.executing_eagerly_outside_functions()
if self.rpc_layer is None and not is_eager:
# Case 3.
return None
# Case 2.
tpus = []
for tpu in compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR):
# We are working around the fact that GKE environment variable that is
# supplied to us has the protocol string embedded in it, but we want
# to strip it out for the ClusterSpec.
if (self.rpc_layer is not None and
tpu.startswith(self.rpc_layer + '://')):
tpus.append(tpu[len(self.rpc_layer + '://'):])
else:
tpus.append(tpu)
cluster_spec = {self.task_type: tpus}
network_endpoints = self._cloud_tpu_client.network_endpoints()
worker_list = [
'%s:%s' % (endpoint['ipAddress'], endpoint['port'])
for endpoint in network_endpoints
]
cluster_spec = {self.task_type: worker_list}
if self._coordinator_address:
# {1, 2}.a
cluster_spec[self._coordinator_name] = [self._coordinator_address]
return server_lib.ClusterSpec(cluster_spec)
def _fetch_cloud_tpu_metadata(self):
"""Returns the TPU metadata object from the TPU Get API call."""
try:
full_name = 'projects/%s/locations/%s/nodes/%s' % (
self._project, self._zone, compat.as_text(self._tpu))
service = self._tpu_service()
request = service.projects().locations().nodes().get(name=full_name)
return request.execute()
except Exception as e:
raise ValueError("Could not lookup TPU metadata from name '%s'. Please "
"doublecheck the tpu argument in the TPUClusterResolver "
"constructor. Exception: %s" % (self._tpu, e))
def num_accelerators(self,
task_type=None,
task_id=None,
@@ -510,12 +301,11 @@ class TPUClusterResolver(ClusterResolver):
return self._environment
def _start_local_server(self):
address = compat.as_text(
self._request_compute_metadata('instance/network-interfaces/0/ip'))
self._server = server_lib.Server(
{
'local': ['0.0.0.0:0']
}, protocol='grpc', config=None, start=True)
address = compat.as_text(self._cloud_tpu_client.get_local_ip())
self._server = server_lib.Server({'local': ['0.0.0.0:0']},
protocol='grpc',
config=None,
start=True)
# self._server.target is of the form: grpc://ipaddress:port
target = compat.as_bytes(self._server.target)
splits = target.split(compat.as_bytes(':'))
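
A short usage sketch of the resolver after this change (illustrative; mirrors the grpc-string test below, where no Cloud API call is made for a 'grpc://' address):

from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver

tpu_resolver = resolver.TPUClusterResolver(tpu='grpc://10.1.2.3:8470')
print(tpu_resolver.master())                         # 'grpc://10.1.2.3:8470'
print(tpu_resolver.cluster_spec().as_cluster_def())  # one 'worker' task at 10.1.2.3:8470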

View File

@@ -25,6 +25,7 @@ from six.moves.urllib.error import URLError
from tensorflow.python import framework
from tensorflow.python.client import session
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
from tensorflow.python.eager.context import LogicalDevice
from tensorflow.python.framework import errors
@@ -32,7 +33,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
mock = test.mock
@@ -58,8 +58,8 @@ class MockNodeClass(object):
return MockRequestClass(name, self._tpu_map)
def mock_request_compute_metadata(cls, *args, **kwargs):
del cls, kwargs # Unused.
def mock_request_compute_metadata(*args, **kwargs):
del kwargs # Unused.
if args[0] == 'project/project-id':
return 'test-project'
elif args[0] == 'instance/zone':
@@ -107,12 +107,12 @@ class TPUClusterResolverTest(test.TestCase):
self.assertProtoEquals(
expected_proto,
server_lib.ClusterSpec(cluster_spec).as_cluster_def())
self.assertProtoEquals(expected_proto,
server_lib.ClusterSpec(
cluster_spec.as_cluster_def()).as_cluster_def())
self.assertProtoEquals(expected_proto,
server_lib.ClusterSpec(
cluster_spec.as_dict()).as_cluster_def())
self.assertProtoEquals(
expected_proto,
server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
self.assertProtoEquals(
expected_proto,
server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
def mock_service_client(self, tpu_map=None):
@@ -130,26 +130,19 @@ class TPUClusterResolverTest(test.TestCase):
return mock_client
@mock.patch.object(resolver, 'is_running_in_gce',
mock_is_running_in_gce)
@mock.patch.object(resolver, 'is_running_in_gce', mock_is_running_in_gce)
def testCheckRunningInGceWithNoTpuName(self):
with self.assertRaisesRegexp(RuntimeError, '.*Google Cloud.*'):
with self.assertRaisesRegexp(ValueError,
'Please provide a TPU Name to connect to.*'):
resolver.TPUClusterResolver(tpu='')
@mock.patch.object(six.moves.urllib.request,
'urlopen',
@mock.patch.object(six.moves.urllib.request, 'urlopen',
mock_running_in_gce_urlopen)
def testIsRunningInGce(self):
self.assertTrue(resolver.is_running_in_gce())
@mock.patch.object(six.moves.urllib.request,
'urlopen',
mock_not_running_in_gce_urlopen)
def testIsNotRunningInGce(self):
self.assertFalse(resolver.is_running_in_gce())
@mock.patch.object(resolver.TPUClusterResolver,
'_request_compute_metadata', mock_request_compute_metadata)
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
mock_request_compute_metadata)
def testRetrieveProjectAndZoneFromMetadata(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
@@ -181,8 +174,8 @@ class TPUClusterResolverTest(test.TestCase):
self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
@mock.patch.object(resolver.TPUClusterResolver,
'_request_compute_metadata', mock_request_compute_metadata)
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
mock_request_compute_metadata)
def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
@@ -207,8 +200,8 @@ class TPUClusterResolverTest(test.TestCase):
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
@mock.patch.object(resolver.TPUClusterResolver,
'_request_compute_metadata', mock_request_compute_metadata)
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
mock_request_compute_metadata)
def testNotReadyCloudTpu(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
@@ -306,8 +299,8 @@ class TPUClusterResolverTest(test.TestCase):
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master())
@mock.patch.object(resolver.TPUClusterResolver,
'_request_compute_metadata', mock_request_compute_metadata)
@mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
mock_request_compute_metadata)
def testPodResolution(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
@@ -425,17 +418,10 @@ class TPUClusterResolverTest(test.TestCase):
coordinator_name=None,
credentials=None,
service=self.mock_service_client(tpu_map={}))
self.assertEqual(should_resolve, cluster_resolver._should_resolve(),
self.assertEqual(should_resolve,
cluster_resolver._cloud_tpu_client.api_available(),
"TPU: '%s'" % tpu)
@mock.patch.object(resolver, 'is_running_in_gce',
mock_is_not_running_in_gce)
def testShouldResolveNoName(self):
self.verifyShouldResolve('', False)
def testShouldResolveLocal(self):
self.verifyShouldResolve('local', False)
def testShouldResolveGrpc(self):
self.verifyShouldResolve('grpc://10.1.2.3:8470', False)
@@ -461,11 +447,6 @@ class TPUClusterResolverTest(test.TestCase):
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
self.assertTrue(resolver.TPUClusterResolver._in_gke())
self.assertEqual(
compat.as_bytes('grpc://10.120.27.5:8470'),
compat.as_bytes(
resolver.TPUClusterResolver._gke_endpoints()))
cluster_resolver = resolver.TPUClusterResolver()
self.assertEqual(
@@ -489,15 +470,6 @@ class TPUClusterResolverTest(test.TestCase):
'grpc://10.120.27.8:8470')
self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
self.assertTrue(resolver.TPUClusterResolver._in_gke())
self.assertEqual(
compat.as_bytes('grpc://10.120.27.5:8470,'
'grpc://10.120.27.6:8470,'
'grpc://10.120.27.7:8470,'
'grpc://10.120.27.8:8470'),
compat.as_bytes(
resolver.TPUClusterResolver._gke_endpoints()))
cluster_resolver = resolver.TPUClusterResolver()
self.assertEqual(
compat.as_bytes('grpc://10.120.27.5:8470'),
@@ -516,17 +488,9 @@ class TPUClusterResolverTest(test.TestCase):
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
def testEnvironmentDiscoveryUrl(self):
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
self.assertEqual(
'https://{api}.internal/{apiVersion}',
(resolver.TPUClusterResolver._environment_discovery_url()))
def testEnvironmentAndRpcDetectionForGrpcString(self):
def testRpcDetectionForGrpcString(self):
cluster_resolver = resolver.TPUClusterResolver(
tpu='grpc://10.1.2.3:8470')
self.assertEqual(cluster_resolver.environment, '')
self.assertEqual(cluster_resolver.rpc_layer, 'grpc')
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
def testOverrideTaskTypeAndIndexAndGetMaster(self):
@@ -569,13 +533,8 @@ class TPUClusterResolverTest(test.TestCase):
cluster_resolver.task_id = 3
self.assertEqual(cluster_resolver.master(), 'grpc://10.2.3.7:8470')
self.assertEqual(
cluster_resolver.master(
task_type='worker', task_id=2, rpc_layer='test'),
'test://10.2.3.6:8470')
def testGetDeviceDictAndCoresWithTPUs(self):
device_names = [
devices = [
'/job:tpu_worker/task:0/device:TPU:0',
'/job:tpu_worker/task:1/device:TPU:1',
'/job:tpu_worker/task:2/device:TPU:0',
@@ -586,8 +545,7 @@ class TPUClusterResolverTest(test.TestCase):
'/job:tpu_worker/task:3/device:TPU:5',
]
device_list = [
session._DeviceAttributes(
name, 'TPU', 1024, 0) for name in device_names
session._DeviceAttributes(name, 'TPU', 1024, 0) for name in devices
]
device_details = resolver.TPUClusterResolver._get_device_dict_and_cores(
@@ -600,7 +558,7 @@ class TPUClusterResolverTest(test.TestCase):
'3': ['1', '5']})
def testGetDeviceDictAndCoresWithCPUsAndGPUs(self):
device_names = [
devices = [
'/job:tpu_worker/task:0/device:CPU:0',
'/job:tpu_worker/task:1/device:CPU:0',
'/job:tpu_worker/task:2/device:CPU:0',
@@ -611,8 +569,7 @@ class TPUClusterResolverTest(test.TestCase):
'/job:tpu_worker/task:3/device:GPU:1',
]
device_list = [
session._DeviceAttributes(
name, 'XLA', 1024, 0) for name in device_names
session._DeviceAttributes(name, 'XLA', 1024, 0) for name in devices
]
device_dict, num_cores =\
@@ -639,8 +596,7 @@ class TPUClusterResolverTest(test.TestCase):
@mock.patch.object(framework.config, 'list_logical_devices')
@mock.patch.object(session.BaseSession, 'list_devices')
@mock.patch.object(resolver, 'is_running_in_gce',
mock_is_not_running_in_gce)
@mock.patch.object(resolver, 'is_running_in_gce', mock_is_not_running_in_gce)
def testNumAcceleratorsSuccess(self, mock_list_devices,
mock_eager_list_devices):
devices = [
@@ -660,16 +616,73 @@ class TPUClusterResolverTest(test.TestCase):
mock_eager_list_devices.return_value = devices
mock_list_devices.return_value = device_list
cluster_resolver = resolver.TPUClusterResolver(tpu='')
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'health':
'HEALTHY',
'networkEndpoints': [
{
'ipAddress': '10.2.3.4',
'port': 8470,
},
{
'ipAddress': '10.2.3.5',
'port': 8470,
},
{
'ipAddress': '10.2.3.6',
'port': 8470,
},
{
'ipAddress': '10.2.3.7',
'port': 8470,
},
]
}
}
cluster_resolver = resolver.TPUClusterResolver(
project='test-project',
zone='us-central1-c',
tpu='test-tpu-1',
service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(cluster_resolver.num_accelerators(), {'TPU': 2})
@mock.patch.object(framework.config, 'list_logical_devices')
@mock.patch.object(session.BaseSession, 'list_devices')
@mock.patch.object(resolver, 'is_running_in_gce',
mock_is_not_running_in_gce)
@mock.patch.object(resolver, 'is_running_in_gce', mock_is_not_running_in_gce)
def testNumAcceleratorsRetryFailure(self, mock_list_devices,
mock_eager_list_devices):
cluster_resolver = resolver.TPUClusterResolver(tpu='')
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'health':
'HEALTHY',
'networkEndpoints': [
{
'ipAddress': '10.2.3.4',
'port': 8470,
},
{
'ipAddress': '10.2.3.5',
'port': 8470,
},
{
'ipAddress': '10.2.3.6',
'port': 8470,
},
{
'ipAddress': '10.2.3.7',
'port': 8470,
},
]
}
}
cluster_resolver = resolver.TPUClusterResolver(
project='test-project',
zone='us-central1-c',
tpu='test-tpu-1',
service=self.mock_service_client(tpu_map=tpu_map))
mock_list_devices.side_effect = errors.DeadlineExceededError(
None, None, 'timeout')
mock_eager_list_devices.side_effect = errors.DeadlineExceededError(

View File

@@ -84,20 +84,10 @@ class _TPUPollingThread(threading.Thread):
return
while self._running:
response = self._cluster._fetch_cloud_tpu_metadata() # pylint: disable=protected-access
logging.warning(
'TPUPollingThread found TPU %s in state %s, and health %s.',
self._cluster._tpu, response['state'], # pylint: disable=protected-access
response.get('health', 'UNKNOWN'))
if 'state' in response and response['state'] in [
'TERMINATED', 'PREEMPTED'
]:
recoverable = self._cluster._cloud_tpu_client.recoverable() # pylint: disable=protected-access
if not recoverable:
logging.warning(
'TPU node %s reached an unrecoverable state %s, '
'terminating training.',
self._cluster._tpu, # pylint: disable=protected-access
response['state'])
'TPUPollingThread found TPU %s in state %s',
self._cluster._tpu, self._cluster._cloud_tpu_client.state()) # pylint: disable=protected-access
os._exit(1) # pylint: disable=protected-access
time.sleep(self._interval)
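# Recoverability contract the polling thread now relies on (for reference,
# mirroring CloudTPUClient.recoverable() and its tests above):
#   state() in ('TERMINATED', 'PREEMPTED')  -> recoverable() is False -> os._exit(1)
#   state() == 'READY' or state() is None   -> recoverable() is True  -> keep polling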