Refactor TPU Cluster Resolver.
PiperOrigin-RevId: 282612898 Change-Id: I826f45bd86f8d986631efe4c36cb87321993087a
parent 602e65243d, commit 7d801fe575
@@ -60,11 +60,56 @@ py_library(
    ],
)

py_library(
    name = "cloud_tpu_client",
    srcs = ["cloud_tpu_client.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:util",
        "@six_archive//:six",
    ],
)

tf_py_test(
    name = "cloud_tpu_client_py_test",
    size = "small",
    srcs = ["cloud_tpu_client_test.py"],
    additional_deps = [
        ":cloud_tpu_client",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:training_server_lib",
    ],
    grpc_enabled = True,
    main = "cloud_tpu_client_test.py",
    python_version = "PY3",
)

tf_py_test(
    name = "cloud_tpu_client_py2_test",
    size = "small",
    srcs = ["cloud_tpu_client_test.py"],
    additional_deps = [
        ":cloud_tpu_client",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:training_server_lib",
    ],
    grpc_enabled = True,
    main = "cloud_tpu_client_test.py",
    python_version = "PY2",
)

py_library(
    name = "tpu_cluster_resolver_py",
    srcs = ["tpu_cluster_resolver.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":cloud_tpu_client",
        ":base_cluster_resolver_py",
        "//tensorflow/python:training_server_lib",
    ] + tf_additional_rpc_deps(),

@@ -0,0 +1,227 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Cloud TPU Client."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from six.moves.urllib import request

from tensorflow.python.util import compat

_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from apiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False

_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal'
_DEFAULT_ENDPOINT_PORT = '8470'


def _environment_discovery_url():
  return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)


def _request_compute_metadata(path):
  req = request.Request(
      '%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path),
      headers={'Metadata-Flavor': 'Google'})
  resp = request.urlopen(req)
  return compat.as_bytes(resp.read())


def _environment_var_to_network_endpoints(endpoints):
  """Yields a dict with ip address and port."""
  for endpoint in endpoints.split(compat.as_text(',')):
    grpc_prefix = compat.as_text('grpc://')
    if endpoint.startswith(grpc_prefix):
      endpoint = endpoint.split(grpc_prefix)[1]
    parts = endpoint.split(compat.as_text(':'))
    ip_address = parts[0]
    port = _DEFAULT_ENDPOINT_PORT
    if len(parts) > 1:
      port = parts[1]
    yield {
        'ipAddress': compat.as_text(ip_address),
        'port': compat.as_text(port)
    }


def _get_tpu_name(tpu):
  if tpu:
    return tpu

  for e in [_GKE_ENV_VARIABLE, _DEFAULT_ENV_VARIABLE]:
    if e in os.environ:
      return os.environ[e]
  return None


class CloudTPUClient(object):
  """Client for working with the Cloud TPU API.

  This client is intended to be used for resolving a TPU name to IP addresses.

  It's recommended to use this library as a context manager to utilize all
  functionality.
  """

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               credentials='default',
               service=None,
               discovery_url=None):
    if isinstance(tpu, list):
      if not tpu:
        raise ValueError('At least one TPU must be specified.')
      if len(tpu) != 1:
        raise NotImplementedError(
            'Using multiple TPUs in a single session is not yet implemented')
      tpu = tpu[0]

    tpu = _get_tpu_name(tpu)

    if tpu is None:
      raise ValueError('Please provide a TPU Name to connect to.')

    self._tpu = compat.as_text(tpu)

    self._use_api = not tpu.startswith('grpc://')
    self._service = service

    self._credentials = None
    self._project = None
    self._zone = None
    self._discovery_url = None
    if self._use_api:
      if credentials != 'default':
        self._credentials = credentials
      # Automatically detect project and zone if unspecified.
      if project:
        self._project = project
      else:
        self._project = compat.as_str(
            _request_compute_metadata('project/project-id'))
      if zone:
        self._zone = zone
      else:
        zone_path = compat.as_str(_request_compute_metadata('instance/zone'))
        self._zone = zone_path.split('/')[-1]
      self._discovery_url = _environment_discovery_url() or discovery_url

  def _tpu_service(self):
    """Creates a new Cloud TPU API object.

    This works around an issue where the underlying HTTP connection sometimes
    times out when the script has been running for too long. Other methods in
    this object call this method to get a new API object whenever they need
    to communicate with the Cloud API.

    Returns:
      A Google Cloud TPU API object.
    """
    if self._service:
      return self._service

    credentials = self._credentials
    if credentials is None or credentials == 'default':
      credentials = GoogleCredentials.get_application_default()

    if self._discovery_url:
      return discovery.build(
          'tpu',
          'v1',
          credentials=credentials,
          discoveryServiceUrl=self._discovery_url,
          cache_discovery=False)
    else:
      return discovery.build(
          'tpu', 'v1', credentials=credentials, cache_discovery=False)

  def _fetch_cloud_tpu_metadata(self):
    """Returns the TPU metadata object from the TPU Get API call."""
    try:
      full_name = 'projects/%s/locations/%s/nodes/%s' % (
          self._project, self._zone, compat.as_text(self._tpu))
      service = self._tpu_service()
      r = service.projects().locations().nodes().get(name=full_name)
      return r.execute()
    except Exception as e:
      raise ValueError("Could not lookup TPU metadata from name '%s'. Please "
                       'doublecheck the tpu argument in the TPUClusterResolver '
                       'constructor. Exception: %s' % (self._tpu, e))

  def __enter__(self):
    self._open = True

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    del type, value, traceback

  def recoverable(self):
    """Returns true if the TPU is in a state where training should eventually resume.

    If false, the TPU is in an unrecoverable state and should be recreated.
    """
    state = self.state()
    if state and state in ['TERMINATED', 'PREEMPTED']:
      return False
    return True

  def state(self):
    """Return the state of the TPU."""
    if self._use_api:
      metadata = self._fetch_cloud_tpu_metadata()
      if 'state' in metadata:
        return metadata['state']

    return None

  def api_available(self):
    """Return whether the Cloud TPU API is available; if not, certain features will not work."""
    return self._use_api

  def name(self):
    """Return the name of the TPU, or the IP address if a name is not provided."""
    return self._tpu

  def get_local_ip(self):
    """Return the local IP address of the Google Cloud VM the workload is running on."""
    return _request_compute_metadata('instance/network-interfaces/0/ip')

  def network_endpoints(self):
    """Return a list of TPU endpoints."""
    if not self._use_api:
      return list(_environment_var_to_network_endpoints(self._tpu))
    response = self._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access

    if 'state' in response and response['state'] != 'READY':
      raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
                         (compat.as_text(self._tpu), response['state']))
    if 'networkEndpoints' in response:
      return response['networkEndpoints']
    else:
      return [{'ipAddress': response['ipAddress'], 'port': response['port']}]
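
The client above exposes name(), state(), recoverable(), api_available(), and network_endpoints(). A minimal usage sketch, assuming the import path introduced in this change; the TPU name is an illustrative placeholder:

# Illustrative sketch only; 'my-tpu' is a placeholder for a real Cloud TPU name.
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client

client = cloud_tpu_client.CloudTPUClient(tpu='my-tpu')

if client.api_available():
  # state() comes from the TPU Get API call; recoverable() is False only for
  # TERMINATED or PREEMPTED nodes, as implemented above.
  print('TPU %s is in state %s' % (client.name(), client.state()))
  if not client.recoverable():
    raise RuntimeError('TPU is in an unrecoverable state; recreate it.')

# Each endpoint is a dict with 'ipAddress' and 'port' keys.
for endpoint in client.network_endpoints():
  print('%s:%s' % (endpoint['ipAddress'], endpoint['port']))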
@@ -0,0 +1,251 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Tests for cloud_tpu_client."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
from tensorflow.python.platform import test

mock = test.mock


def mock_request_compute_metadata(path):
  if path == 'project/project-id':
    return 'test-project'
  elif path == 'instance/zone':
    return 'projects/test-project/locations/us-central1-c'
  elif path == 'instance/network-interfaces/0/ip':
    return '10.128.1.2'
  return ''


class MockRequestClass(object):

  def __init__(self, name, tpu_map):
    self._name = name
    self._tpu_map = tpu_map

  def execute(self):
    if self._name in self._tpu_map:
      return self._tpu_map[self._name]
    else:
      raise KeyError('Resource %s was not found' % self._name)


class MockNodeClass(object):

  def __init__(self, tpu_map):
    self._tpu_map = tpu_map

  def get(self, name):
    return MockRequestClass(name, self._tpu_map)


class CloudTpuClientTest(test.TestCase):

  def setUp(self):
    super(CloudTpuClientTest, self).setUp()
    if 'TPU_API_DISCOVERY_URL' in os.environ:
      del os.environ['TPU_API_DISCOVERY_URL']
    if 'TPU_NAME' in os.environ:
      del os.environ['TPU_NAME']

  def mock_service_client(self, tpu_map=None):
    if tpu_map is None:
      tpu_map = {}

    mock_locations = mock.MagicMock()
    mock_locations.nodes.return_value = MockNodeClass(tpu_map)

    mock_project = mock.MagicMock()
    mock_project.locations.return_value = mock_locations

    mock_client = mock.MagicMock()
    mock_client.projects.return_value = mock_project
    return mock_client

  def testEnvironmentDiscoveryUrl(self):
    os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
    self.assertEqual('https://{api}.internal/{apiVersion}',
                     (cloud_tpu_client._environment_discovery_url()))

  def testEnvironmentVarToNetworkEndpointsSingleIp(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '1234'}],
        list(cloud_tpu_client._environment_var_to_network_endpoints(
            '1.2.3.4:1234')))

  def testEnvironmentVarToNetworkEndpointsSingleGrpcAddress(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'}],
        list(
            cloud_tpu_client._environment_var_to_network_endpoints(
                'grpc://1.2.3.4:2000')))

  def testEnvironmentVarToNetworkEndpointsMultipleIps(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'},
         {'ipAddress': '5.6.7.8', 'port': '1234'}],
        list(
            cloud_tpu_client._environment_var_to_network_endpoints(
                '1.2.3.4:2000,5.6.7.8:1234')))

  def testEnvironmentVarToNetworkEndpointsMultipleGrpcAddresses(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'},
         {'ipAddress': '5.6.7.8', 'port': '1234'}],
        list(cloud_tpu_client._environment_var_to_network_endpoints(
            'grpc://1.2.3.4:2000,grpc://5.6.7.8:1234')))

  def testEnvironmentVarToNetworkEndpointsMissingPortAndMixed(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'},
         {'ipAddress': '5.6.7.8', 'port': '8470'}],
        list(cloud_tpu_client._environment_var_to_network_endpoints(
            '1.2.3.4:2000,grpc://5.6.7.8')))

  def testInitializeNoArguments(self):
    with self.assertRaisesRegex(
        ValueError, 'Please provide a TPU Name to connect to.'):
      cloud_tpu_client.CloudTPUClient()

  def testInitializeMultiElementTpuArray(self):
    with self.assertRaisesRegex(
        NotImplementedError,
        'Using multiple TPUs in a single session is not yet implemented'):
      cloud_tpu_client.CloudTPUClient(tpu=['multiple', 'elements'])

  def assertClientContains(self, client):
    self.assertEqual('tpu_name', client._tpu)
    self.assertEqual(True, client._use_api)
    self.assertEqual(None, client._credentials)
    self.assertEqual('test-project', client._project)
    self.assertEqual('us-central1-c', client._zone)
    self.assertEqual(None, client._discovery_url)
    self.assertEqual([{
        'ipAddress': '10.1.2.3',
        'port': '8470'
    }], client.network_endpoints())

  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testInitializeNoArgumentsWithEnvironmentVariable(self):
    os.environ['TPU_NAME'] = 'tpu_name'
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'health': 'HEALTHY'
        }
    }
    client = cloud_tpu_client.CloudTPUClient(
        service=self.mock_service_client(tpu_map=tpu_map))
    self.assertClientContains(client)

  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testInitializeTpuName(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'health': 'HEALTHY'
        }
    }
    client = cloud_tpu_client.CloudTPUClient(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertClientContains(client)

  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testInitializeIpAddress(self):
    client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
    self.assertEqual('grpc://1.2.3.4:8470', client._tpu)
    self.assertEqual(False, client._use_api)
    self.assertEqual(None, client._service)
    self.assertEqual(None, client._credentials)
    self.assertEqual(None, client._project)
    self.assertEqual(None, client._zone)
    self.assertEqual(None, client._discovery_url)
    self.assertEqual([{
        'ipAddress': '1.2.3.4',
        'port': '8470'
    }], client.network_endpoints())

  def testInitializeWithoutMetadata(self):
    client = cloud_tpu_client.CloudTPUClient(
        tpu='tpu_name', project='project', zone='zone')
    self.assertEqual('tpu_name', client._tpu)
    self.assertEqual(True, client._use_api)
    self.assertEqual(None, client._service)
    self.assertEqual(None, client._credentials)
    self.assertEqual('project', client._project)
    self.assertEqual('zone', client._zone)
    self.assertEqual(None, client._discovery_url)

  def testRecoverableNoApiAccess(self):
    client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
    self.assertEqual(True, client.recoverable())

  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRecoverableNoState(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
            'ipAddress': '10.1.2.3',
            'port': '8470',
        }
    }
    client = cloud_tpu_client.CloudTPUClient(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual(True, client.recoverable())

  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRecoverableReady(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'state': 'READY',
        }
    }
    client = cloud_tpu_client.CloudTPUClient(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual(True, client.recoverable())

  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRecoverablePreempted(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'state': 'PREEMPTED',
        }
    }
    client = cloud_tpu_client.CloudTPUClient(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual(False, client.recoverable())


if __name__ == '__main__':
  test.main()
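
The parsing behavior exercised above also covers the case where the TPU address comes from an environment variable rather than the Cloud TPU API. A small end-to-end sketch of that path (the addresses are placeholders):

# Illustrative sketch: when TPU_NAME holds grpc:// endpoints, CloudTPUClient
# never touches the Cloud TPU API and parses the addresses directly.
import os

from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client

os.environ['TPU_NAME'] = 'grpc://10.0.0.1:8470,10.0.0.2'

client = cloud_tpu_client.CloudTPUClient()
assert not client.api_available()

# An entry without a port falls back to the default 8470.
print(client.network_endpoints())
# [{'ipAddress': '10.0.0.1', 'port': '8470'}, {'ipAddress': '10.0.0.2', 'port': '8470'}]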
@@ -19,61 +19,30 @@ from __future__ import division
from __future__ import print_function

import collections
import os
import re

from six.moves import urllib
from six.moves.urllib.error import URLError
from six.moves.urllib.request import Request
from six.moves.urllib.request import urlopen

from tensorflow.python.distribute.cluster_resolver.cloud_tpu_client import CloudTPUClient
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import get_accelerator_devices
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False

_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
def is_running_in_gce():
  return True


_TPU_DEVICE_REGEX = re.compile(
    r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
_TPU_CONN_RETRIES = 120

_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal'

DeviceDetails = collections.namedtuple(
    'DeviceDetails', ['device_map', 'total_cores'])


def is_running_in_gce():
  """Checks for GCE presence by attempting to query the metadata service."""
  try:
    req = Request(
        '%s/computeMetadata/v1' % _GCE_METADATA_ENDPOINT,
        headers={'Metadata-Flavor': 'Google'})
    resp = urllib.request.urlopen(req, timeout=1)
    info = resp.info()
    if 'Metadata-Flavor' in info and info['Metadata-Flavor'] == 'Google':
      return True
  except URLError:
    pass
  return False


@tf_export('distribute.cluster_resolver.TPUClusterResolver')
class TPUClusterResolver(ClusterResolver):
  """Cluster Resolver for Google Cloud TPUs.
@@ -89,53 +58,6 @@ class TPUClusterResolver(ClusterResolver):
  Google internal
  """

  def _tpu_service(self):
    """Creates a new Cloud TPU API object.

    This works around an issue where the underlying HTTP connection sometimes
    times out when the script has been running for too long. Other methods in
    this object call this method to get a new API object whenever they need
    to communicate with the Cloud API.

    Returns:
      A Google Cloud TPU API object.
    """
    if self._service:
      return self._service

    credentials = self._credentials
    if credentials is None or credentials == 'default':
      credentials = GoogleCredentials.get_application_default()

    if self._discovery_url:
      return discovery.build(
          'tpu',
          'v1',
          credentials=credentials,
          discoveryServiceUrl=self._discovery_url,
          cache_discovery=False)
    else:
      return discovery.build(
          'tpu', 'v1', credentials=credentials, cache_discovery=False)

  def _request_compute_metadata(self, path):
    req = Request('%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path),
                  headers={'Metadata-Flavor': 'Google'})
    resp = urlopen(req)
    return compat.as_bytes(resp.read())

  def _is_local_tpu(self):
    return (
        self._tpu == compat.as_bytes('') or
        self._tpu == compat.as_bytes('local'))

  def _should_resolve(self):
    if isinstance(self._should_resolve_override, bool):
      return self._should_resolve_override
    else:
      return not (self._tpu.startswith(compat.as_bytes('grpc://')) or
                  self._is_local_tpu())

  @staticmethod
  def _get_device_dict_and_cores(devices):
    """Returns a dict of hosts to cores and total cores given devices names.
@@ -168,25 +90,6 @@ class TPUClusterResolver(ClusterResolver):
          'should never happen. Devices: {}'.format(device_dict))
    return num_cores_per_host_set.pop()

  @staticmethod
  def _in_gke():
    """When running in GKE, the environment variable will be set."""
    return _GKE_ENV_VARIABLE in os.environ

  @staticmethod
  def _gke_endpoints():
    return os.environ[_GKE_ENV_VARIABLE]

  @staticmethod
  def _env_var_fallback():
    if _DEFAULT_ENV_VARIABLE in os.environ:
      return os.environ[_DEFAULT_ENV_VARIABLE]
    return None

  @staticmethod
  def _environment_discovery_url():
    return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)

  def __init__(self,
               tpu=None,
               zone=None,
@@ -204,11 +107,11 @@ class TPUClusterResolver(ClusterResolver):

    Args:
      tpu: A string corresponding to the TPU to use. If the string is an empty
        string, the string 'local', or a string that begins with 'grpc://',
        then it is assumed to not correspond with a Cloud TPU and will
        instead be passed as the session master and no ClusterSpec propagation
        will be done. In the future, this may also support a list of strings
        when multiple Cloud TPUs are used.
        string, the string 'local', or a string that begins with 'grpc://', then
        it is assumed to not correspond with a Cloud TPU and will instead be
        passed as the session master and no ClusterSpec propagation will be
        done. In the future, this may also support a list of strings when
        multiple Cloud TPUs are used.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
@@ -232,6 +135,7 @@ class TPUClusterResolver(ClusterResolver):
        filled in produce an absolute URL to the discovery document for that
        service. The environment variable 'TPU_API_DISCOVERY_URL' will override
        this.
      **kwargs: Extra keyword arguments passed to CloudTPUClient.

    Raises:
      ImportError: If the googleapiclient is not installed.
@@ -239,92 +143,32 @@ class TPUClusterResolver(ClusterResolver):
      RuntimeError: If an empty TPU name is specified and this is running in a
        Google Cloud environment.
    """
    if isinstance(tpu, list):
      if not tpu:
        raise ValueError('At least one TPU must be specified.')
      if len(tpu) != 1:
        raise NotImplementedError(
            'Using multiple TPUs in a single session is not yet implemented')
      tpu = tpu[0]

    in_gke = self._in_gke()
    # When using GKE with Cloud TPUs, the env variable will be set.
    if tpu is None:
      if in_gke:
        tpu = self._gke_endpoints()
      else:
        tpu = self._env_var_fallback()

    if tpu is None:
      raise ValueError('Please provide a TPU Name to connect to.')

    self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes

    # If we are running in Cloud and don't specify a TPU name
    if is_running_in_gce() and not self._tpu:
      raise RuntimeError('You need to specify a TPU Name if you are running in '
                         'the Google Cloud environment.')
    self._cloud_tpu_client = CloudTPUClient(
        tpu=tpu,
        zone=zone,
        project=project,
        credentials=credentials,
        service=service,
        discovery_url=discovery_url)

    self._tpu = self._cloud_tpu_client.name()
    # By default the task_type is 'worker' and the task_id is 0 (which is the
    # first worker in the task).
    self.task_type = job_name
    self.task_id = 0

    if self._is_local_tpu():
      self.rpc_layer = None
    else:
      self._environment = ''
      self.rpc_layer = 'grpc'

    # Setting this overrides the return value of self._should_resolve()
    self._should_resolve_override = None

    # We strip out the protocol if it is included, and override the
    # shouldResolve function to never resolve. We are adding the protocol back
    # in later in self.master().
    if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'):
      tpu = tpu[len(self.rpc_layer + '://'):]
      self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes
      self._should_resolve_override = False

    # Whether we should actually attempt to contact Cloud APIs
    should_resolve = self._should_resolve()

    # We error out if we are in a non-Cloud environment which cannot talk to the
    # Cloud APIs using the standard class and a special object is not passed in.
    self._service = service
    if (self._service is None and should_resolve and
        not _GOOGLE_API_CLIENT_INSTALLED):
      raise ImportError('googleapiclient and oauth2client must be installed '
                        'before using the TPU cluster resolver. Execute: '
                        '`pip install --upgrade google-api-python-client '
                        'oauth2client` to install with pip.')

    # We save user-passed credentials, unless the user didn't pass in anything.
    self._credentials = credentials
    if (credentials == 'default' and should_resolve and
        _GOOGLE_API_CLIENT_INSTALLED):
      self._credentials = None

    # Automatically detect project and zone if unspecified.
    if not project and should_resolve:
      project = compat.as_str(
          self._request_compute_metadata('project/project-id'))
    if not zone and should_resolve:
      zone_path = compat.as_str(self._request_compute_metadata('instance/zone'))
      zone = zone_path.split('/')[-1]
    self._project = project
    self._zone = zone

    self._discovery_url = self._environment_discovery_url() or discovery_url

    self._coordinator_name = coordinator_name
    if (coordinator_name and not coordinator_address and
        (should_resolve or in_gke)):
    if (coordinator_name and not coordinator_address):
      self._start_local_server()
    else:
      self._coordinator_address = coordinator_address

  def __enter__(self):
    self._cloud_tpu_client.enter()

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    self._cloud_tpu_client.exit(type, value, traceback)

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Get the Master string to be used for the session.

@@ -350,35 +194,27 @@ class TPUClusterResolver(ClusterResolver):
    Raises:
      ValueError: If none of the TPUs specified exists.
    """
    if self._should_resolve():
      # We are going to communicate with the Cloud TPU APIs to get a Cluster.
      cluster_spec = self.cluster_spec()
      if task_type is not None and task_id is not None:
        # task_type and task_id is from the function parameter
        master = cluster_spec.task_address(task_type, task_id)
      elif self.task_type is not None and self.task_id is not None:
        # task_type and task_id is from the object
        master = cluster_spec.task_address(self.task_type, self.task_id)
      else:
        # by default we take the first item in the cluster with the right name
        job_tasks = cluster_spec.job_tasks(self.task_type)
        if not job_tasks:
          raise ValueError('No TPUs with the specified names exist.')
        master = job_tasks[0]

    cluster_spec = self.cluster_spec()
    if task_type is not None and task_id is not None:
      # task_type and task_id is from the function parameter
      master = cluster_spec.task_address(task_type, task_id)
    elif self.task_type is not None and self.task_id is not None:
      # task_type and task_id is from the object
      master = cluster_spec.task_address(self.task_type, self.task_id)
    else:
      if isinstance(self._tpu, (bytes, bytearray)):
        master = compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR)[0]
      else:
        master = self._tpu.split(_ENDPOINTS_SEPARATOR)[0]
    return format_master_url(master, rpc_layer or self.rpc_layer)
      # by default we take the first item in the cluster with the right name
      job_tasks = cluster_spec.job_tasks(self.task_type)
      if not job_tasks:
        raise ValueError('No TPUs with the specified names exist.')
      master = job_tasks[0]
    return format_master_url(master, 'grpc')

  def get_master(self):
    return self.master()

  def get_job_name(self):
    if ops.executing_eagerly_outside_functions() or self._should_resolve(
    ) or is_running_in_gce():
      return self.task_type
    return self.task_type

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.
@@ -402,65 +238,20 @@ class TPUClusterResolver(ClusterResolver):
    # tasks and
    # a. Create a ClusterSpec with the coordinator
    # b. Create a ClusterSpec without the coordinator
    # 3. [Other (legacy non-gRPC).] We should return None.
    ############################################################################

    if self._should_resolve():
      # Case 1.
      response = self._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access

      if 'state' in response and response['state'] != 'READY':
        raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
                           (compat.as_text(self._tpu), response['state']))

      if 'networkEndpoints' in response:
        worker_list = [
            '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
            for endpoint in response['networkEndpoints']
        ]
      else:
        # Fall back to the deprecated response format
        instance_url = '%s:%s' % (response['ipAddress'], response['port'])
        worker_list = [instance_url]

      cluster_spec = {self.task_type: worker_list}
    else:
      is_eager = ops.executing_eagerly_outside_functions()
      if self.rpc_layer is None and not is_eager:
        # Case 3.
        return None
      # Case 2.
      tpus = []
      for tpu in compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR):
        # We are working around the fact that GKE environment variable that is
        # supplied to us has the protocol string embedded in it, but we want
        # to strip it out for the ClusterSpec.
        if (self.rpc_layer is not None and
            tpu.startswith(self.rpc_layer + '://')):
          tpus.append(tpu[len(self.rpc_layer + '://'):])
        else:
          tpus.append(tpu)
      cluster_spec = {self.task_type: tpus}

    network_endpoints = self._cloud_tpu_client.network_endpoints()
    worker_list = [
        '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
        for endpoint in network_endpoints
    ]
    cluster_spec = {self.task_type: worker_list}
    if self._coordinator_address:
      # {1, 2}.a
      cluster_spec[self._coordinator_name] = [self._coordinator_address]

    return server_lib.ClusterSpec(cluster_spec)

  def _fetch_cloud_tpu_metadata(self):
    """Returns the TPU metadata object from the TPU Get API call."""
    try:
      full_name = 'projects/%s/locations/%s/nodes/%s' % (
          self._project, self._zone, compat.as_text(self._tpu))
      service = self._tpu_service()
      request = service.projects().locations().nodes().get(name=full_name)
      return request.execute()
    except Exception as e:
      raise ValueError("Could not lookup TPU metadata from name '%s'. Please "
                       "doublecheck the tpu argument in the TPUClusterResolver "
                       "constructor. Exception: %s" % (self._tpu, e))

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
@@ -510,12 +301,11 @@ class TPUClusterResolver(ClusterResolver):
    return self._environment

  def _start_local_server(self):
    address = compat.as_text(
        self._request_compute_metadata('instance/network-interfaces/0/ip'))
    self._server = server_lib.Server(
        {
            'local': ['0.0.0.0:0']
        }, protocol='grpc', config=None, start=True)
    address = compat.as_text(self._cloud_tpu_client.get_local_ip())
    self._server = server_lib.Server({'local': ['0.0.0.0:0']},
                                     protocol='grpc',
                                     config=None,
                                     start=True)
    # self._server.target is of the form: grpc://ipaddress:port
    target = compat.as_bytes(self._server.target)
    splits = target.split(compat.as_bytes(':'))

@@ -25,6 +25,7 @@ from six.moves.urllib.error import URLError

from tensorflow.python import framework
from tensorflow.python.client import session
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
from tensorflow.python.eager.context import LogicalDevice
from tensorflow.python.framework import errors
@@ -32,7 +33,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat

mock = test.mock


@@ -58,8 +58,8 @@ class MockNodeClass(object):
    return MockRequestClass(name, self._tpu_map)


def mock_request_compute_metadata(cls, *args, **kwargs):
  del cls, kwargs  # Unused.
def mock_request_compute_metadata(*args, **kwargs):
  del kwargs  # Unused.
  if args[0] == 'project/project-id':
    return 'test-project'
  elif args[0] == 'instance/zone':
@@ -107,12 +107,12 @@ class TPUClusterResolverTest(test.TestCase):
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec).as_cluster_def())
    self.assertProtoEquals(expected_proto,
                           server_lib.ClusterSpec(
                               cluster_spec.as_cluster_def()).as_cluster_def())
    self.assertProtoEquals(expected_proto,
                           server_lib.ClusterSpec(
                               cluster_spec.as_dict()).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())

  def mock_service_client(self, tpu_map=None):

@@ -130,26 +130,19 @@ class TPUClusterResolverTest(test.TestCase):

    return mock_client

  @mock.patch.object(resolver, 'is_running_in_gce',
                     mock_is_running_in_gce)
  @mock.patch.object(resolver, 'is_running_in_gce', mock_is_running_in_gce)
  def testCheckRunningInGceWithNoTpuName(self):
    with self.assertRaisesRegexp(RuntimeError, '.*Google Cloud.*'):
    with self.assertRaisesRegexp(ValueError,
                                 'Please provide a TPU Name to connect to.*'):
      resolver.TPUClusterResolver(tpu='')

  @mock.patch.object(six.moves.urllib.request,
                     'urlopen',
  @mock.patch.object(six.moves.urllib.request, 'urlopen',
                     mock_running_in_gce_urlopen)
  def testIsRunningInGce(self):
    self.assertTrue(resolver.is_running_in_gce())

  @mock.patch.object(six.moves.urllib.request,
                     'urlopen',
                     mock_not_running_in_gce_urlopen)
  def testIsNotRunningInGce(self):
    self.assertFalse(resolver.is_running_in_gce())

  @mock.patch.object(resolver.TPUClusterResolver,
                     '_request_compute_metadata', mock_request_compute_metadata)
  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRetrieveProjectAndZoneFromMetadata(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
@@ -181,8 +174,8 @@ class TPUClusterResolverTest(test.TestCase):
    self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
    self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')

  @mock.patch.object(resolver.TPUClusterResolver,
                     '_request_compute_metadata', mock_request_compute_metadata)
  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
@@ -207,8 +200,8 @@ class TPUClusterResolverTest(test.TestCase):
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')

  @mock.patch.object(resolver.TPUClusterResolver,
                     '_request_compute_metadata', mock_request_compute_metadata)
  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testNotReadyCloudTpu(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
@@ -306,8 +299,8 @@ class TPUClusterResolverTest(test.TestCase):
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master())

  @mock.patch.object(resolver.TPUClusterResolver,
                     '_request_compute_metadata', mock_request_compute_metadata)
  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testPodResolution(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
@@ -425,17 +418,10 @@ class TPUClusterResolverTest(test.TestCase):
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map={}))
    self.assertEqual(should_resolve, cluster_resolver._should_resolve(),
    self.assertEqual(should_resolve,
                     cluster_resolver._cloud_tpu_client.api_available(),
                     "TPU: '%s'" % tpu)

  @mock.patch.object(resolver, 'is_running_in_gce',
                     mock_is_not_running_in_gce)
  def testShouldResolveNoName(self):
    self.verifyShouldResolve('', False)

  def testShouldResolveLocal(self):
    self.verifyShouldResolve('local', False)

  def testShouldResolveGrpc(self):
    self.verifyShouldResolve('grpc://10.1.2.3:8470', False)

@@ -461,11 +447,6 @@ class TPUClusterResolverTest(test.TestCase):
    os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'

    self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
    self.assertTrue(resolver.TPUClusterResolver._in_gke())
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470'),
        compat.as_bytes(
            resolver.TPUClusterResolver._gke_endpoints()))

    cluster_resolver = resolver.TPUClusterResolver()
    self.assertEqual(
@@ -489,15 +470,6 @@ class TPUClusterResolverTest(test.TestCase):
        'grpc://10.120.27.8:8470')

    self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
    self.assertTrue(resolver.TPUClusterResolver._in_gke())
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470,'
                        'grpc://10.120.27.6:8470,'
                        'grpc://10.120.27.7:8470,'
                        'grpc://10.120.27.8:8470'),
        compat.as_bytes(
            resolver.TPUClusterResolver._gke_endpoints()))

    cluster_resolver = resolver.TPUClusterResolver()
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470'),
@@ -516,17 +488,9 @@ class TPUClusterResolverTest(test.TestCase):

    del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']

  def testEnvironmentDiscoveryUrl(self):
    os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
    self.assertEqual(
        'https://{api}.internal/{apiVersion}',
        (resolver.TPUClusterResolver._environment_discovery_url()))

  def testEnvironmentAndRpcDetectionForGrpcString(self):
  def testRpcDetectionForGrpcString(self):
    cluster_resolver = resolver.TPUClusterResolver(
        tpu='grpc://10.1.2.3:8470')
    self.assertEqual(cluster_resolver.environment, '')
    self.assertEqual(cluster_resolver.rpc_layer, 'grpc')
    self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')

  def testOverrideTaskTypeAndIndexAndGetMaster(self):
@@ -569,13 +533,8 @@ class TPUClusterResolverTest(test.TestCase):
    cluster_resolver.task_id = 3
    self.assertEqual(cluster_resolver.master(), 'grpc://10.2.3.7:8470')

    self.assertEqual(
        cluster_resolver.master(
            task_type='worker', task_id=2, rpc_layer='test'),
        'test://10.2.3.6:8470')

  def testGetDeviceDictAndCoresWithTPUs(self):
    device_names = [
    devices = [
        '/job:tpu_worker/task:0/device:TPU:0',
        '/job:tpu_worker/task:1/device:TPU:1',
        '/job:tpu_worker/task:2/device:TPU:0',
@@ -586,8 +545,7 @@ class TPUClusterResolverTest(test.TestCase):
        '/job:tpu_worker/task:3/device:TPU:5',
    ]
    device_list = [
        session._DeviceAttributes(
            name, 'TPU', 1024, 0) for name in device_names
        session._DeviceAttributes(name, 'TPU', 1024, 0) for name in devices
    ]

    device_details = resolver.TPUClusterResolver._get_device_dict_and_cores(
@@ -600,7 +558,7 @@ class TPUClusterResolverTest(test.TestCase):
                                           '3': ['1', '5']})

  def testGetDeviceDictAndCoresWithCPUsAndGPUs(self):
    device_names = [
    devices = [
        '/job:tpu_worker/task:0/device:CPU:0',
        '/job:tpu_worker/task:1/device:CPU:0',
        '/job:tpu_worker/task:2/device:CPU:0',
@@ -611,8 +569,7 @@ class TPUClusterResolverTest(test.TestCase):
        '/job:tpu_worker/task:3/device:GPU:1',
    ]
    device_list = [
        session._DeviceAttributes(
            name, 'XLA', 1024, 0) for name in device_names
        session._DeviceAttributes(name, 'XLA', 1024, 0) for name in devices
    ]

    device_dict, num_cores =\
@@ -639,8 +596,7 @@ class TPUClusterResolverTest(test.TestCase):

  @mock.patch.object(framework.config, 'list_logical_devices')
  @mock.patch.object(session.BaseSession, 'list_devices')
  @mock.patch.object(resolver, 'is_running_in_gce',
                     mock_is_not_running_in_gce)
  @mock.patch.object(resolver, 'is_running_in_gce', mock_is_not_running_in_gce)
  def testNumAcceleratorsSuccess(self, mock_list_devices,
                                 mock_eager_list_devices):
    devices = [
@@ -660,16 +616,73 @@ class TPUClusterResolverTest(test.TestCase):
    mock_eager_list_devices.return_value = devices
    mock_list_devices.return_value = device_list

    cluster_resolver = resolver.TPUClusterResolver(tpu='')
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'health':
                'HEALTHY',
            'networkEndpoints': [
                {
                    'ipAddress': '10.2.3.4',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.5',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.6',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.7',
                    'port': 8470,
                },
            ]
        }
    }

    cluster_resolver = resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual(cluster_resolver.num_accelerators(), {'TPU': 2})

  @mock.patch.object(framework.config, 'list_logical_devices')
  @mock.patch.object(session.BaseSession, 'list_devices')
  @mock.patch.object(resolver, 'is_running_in_gce',
                     mock_is_not_running_in_gce)
  @mock.patch.object(resolver, 'is_running_in_gce', mock_is_not_running_in_gce)
  def testNumAcceleratorsRetryFailure(self, mock_list_devices,
                                      mock_eager_list_devices):
    cluster_resolver = resolver.TPUClusterResolver(tpu='')
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'health':
                'HEALTHY',
            'networkEndpoints': [
                {
                    'ipAddress': '10.2.3.4',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.5',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.6',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.7',
                    'port': 8470,
                },
            ]
        }
    }

    cluster_resolver = resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        service=self.mock_service_client(tpu_map=tpu_map))
    mock_list_devices.side_effect = errors.DeadlineExceededError(
        None, None, 'timeout')
    mock_eager_list_devices.side_effect = errors.DeadlineExceededError(

@@ -84,20 +84,10 @@ class _TPUPollingThread(threading.Thread):
      return

    while self._running:
      response = self._cluster._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access
      logging.warning(
          'TPUPollingThread found TPU %s in state %s, and health %s.',
          self._cluster._tpu, response['state'],  # pylint: disable=protected-access
          response.get('health', 'UNKNOWN'))

      if 'state' in response and response['state'] in [
          'TERMINATED', 'PREEMPTED'
      ]:
      recoverable = self._cluster._cloud_tpu_client.recoverable()  # pylint: disable=protected-access
      if not recoverable:
        logging.warning(
            'TPU node %s reached an unrecoverable state %s, '
            'terminating training.',
            self._cluster._tpu,  # pylint: disable=protected-access
            response['state'])
            'TPUPollingThread found TPU %s in state %s',
            self._cluster._tpu, self._cluster._cloud_tpu_client.state())  # pylint: disable=protected-access
        os._exit(1)  # pylint: disable=protected-access

      time.sleep(self._interval)
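
Taken together, the resolver now delegates all Cloud TPU API access to the new client. A sketch of typical resolver usage after this change; the TPU name is a placeholder and the surrounding training setup is omitted:

# Illustrative sketch: the resolver builds its ClusterSpec from
# CloudTPUClient.network_endpoints(), as shown in cluster_spec() above.
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver

resolver = tpu_cluster_resolver.TPUClusterResolver(tpu='my-tpu')

print(resolver.master())                  # e.g. grpc://<ip>:8470
print(resolver.cluster_spec().as_dict())  # e.g. {'worker': ['<ip>:8470', ...]}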