Refactor Tpu Cluster Resolver.
PiperOrigin-RevId: 282612898 Change-Id: I826f45bd86f8d986631efe4c36cb87321993087a
This commit is contained in: parent 602e65243d, commit 7d801fe575
@@ -60,11 +60,56 @@ py_library(
     ],
 )
 
+py_library(
+    name = "cloud_tpu_client",
+    srcs = ["cloud_tpu_client.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:util",
+        "@six_archive//:six",
+    ],
+)
+
+tf_py_test(
+    name = "cloud_tpu_client_py_test",
+    size = "small",
+    srcs = ["cloud_tpu_client_test.py"],
+    additional_deps = [
+        ":cloud_tpu_client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:training_server_lib",
+    ],
+    grpc_enabled = True,
+    main = "cloud_tpu_client_test.py",
+    python_version = "PY3",
+)
+
+tf_py_test(
+    name = "cloud_tpu_client_py2_test",
+    size = "small",
+    srcs = ["cloud_tpu_client_test.py"],
+    additional_deps = [
+        ":cloud_tpu_client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:training_server_lib",
+    ],
+    grpc_enabled = True,
+    main = "cloud_tpu_client_test.py",
+    python_version = "PY2",
+)
+
 py_library(
     name = "tpu_cluster_resolver_py",
     srcs = ["tpu_cluster_resolver.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":cloud_tpu_client",
         ":base_cluster_resolver_py",
         "//tensorflow/python:training_server_lib",
     ] + tf_additional_rpc_deps(),
@@ -0,0 +1,227 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Lint as: python3
+"""Cloud TPU Client."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from six.moves.urllib import request
+
+from tensorflow.python.util import compat
+
+_GOOGLE_API_CLIENT_INSTALLED = True
+try:
+  from apiclient import discovery  # pylint: disable=g-import-not-at-top
+  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
+except ImportError:
+  _GOOGLE_API_CLIENT_INSTALLED = False
+
+_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
+_ENDPOINTS_SEPARATOR = ','
+_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
+_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
+_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal'
+_DEFAULT_ENDPOINT_PORT = '8470'
+
+
+def _environment_discovery_url():
+  return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
+
+
+def _request_compute_metadata(path):
+  req = request.Request(
+      '%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path),
+      headers={'Metadata-Flavor': 'Google'})
+  resp = request.urlopen(req)
+  return compat.as_bytes(resp.read())
+
+
+def _environment_var_to_network_endpoints(endpoints):
+  """Yields a dict with ip address and port."""
+  for endpoint in endpoints.split(compat.as_text(',')):
+    grpc_prefix = compat.as_text('grpc://')
+    if endpoint.startswith(grpc_prefix):
+      endpoint = endpoint.split(grpc_prefix)[1]
+    parts = endpoint.split(compat.as_text(':'))
+    ip_address = parts[0]
+    port = _DEFAULT_ENDPOINT_PORT
+    if len(parts) > 1:
+      port = parts[1]
+    yield {
+        'ipAddress': compat.as_text(ip_address),
+        'port': compat.as_text(port)
+    }
+
+
+def _get_tpu_name(tpu):
+  if tpu:
+    return tpu
+
+  for e in [_GKE_ENV_VARIABLE, _DEFAULT_ENV_VARIABLE]:
+    if e in os.environ:
+      return os.environ[e]
+  return None
+
+
+class CloudTPUClient(object):
+  """Client for working with the Cloud TPU API.
+
+  This client is intended to be used for resolving tpu name to ip addresses.
+
+  It's recommended to use this library as a contextlib to utilize all
+  functionality.
+  """
+
+  def __init__(self,
+               tpu=None,
+               zone=None,
+               project=None,
+               credentials='default',
+               service=None,
+               discovery_url=None):
+    if isinstance(tpu, list):
+      if not tpu:
+        raise ValueError('At least one TPU must be specified.')
+      if len(tpu) != 1:
+        raise NotImplementedError(
+            'Using multiple TPUs in a single session is not yet implemented')
+      tpu = tpu[0]
+
+    tpu = _get_tpu_name(tpu)
+
+    if tpu is None:
+      raise ValueError('Please provide a TPU Name to connect to.')
+
+    self._tpu = compat.as_text(tpu)
+
+    self._use_api = not tpu.startswith('grpc://')
+    self._service = service
+
+    self._credentials = None
+    self._project = None
+    self._zone = None
+    self._discovery_url = None
+    if self._use_api:
+      if credentials != 'default':
+        self._credentials = credentials
+      # Automaically detect project and zone if unspecified.
+      if project:
+        self._project = project
+      else:
+        self._project = compat.as_str(
+            _request_compute_metadata('project/project-id'))
+      if zone:
+        self._zone = zone
+      else:
+        zone_path = compat.as_str(_request_compute_metadata('instance/zone'))
+        self._zone = zone_path.split('/')[-1]
+      self._discovery_url = _environment_discovery_url() or discovery_url
+
+  def _tpu_service(self):
+    """Creates a new Cloud TPU API object.
+
+    This works around an issue where the underlying HTTP connection sometimes
+    times out when the script has been running for too long. Other methods in
+    this object call this method to get a new API object whenever they need
+    to communicate with the Cloud API.
+
+    Returns:
+      A Google Cloud TPU API object.
+    """
+    if self._service:
+      return self._service
+
+    credentials = self._credentials
+    if credentials is None or credentials == 'default':
+      credentials = GoogleCredentials.get_application_default()
+
+    if self._discovery_url:
+      return discovery.build(
+          'tpu',
+          'v1',
+          credentials=credentials,
+          discoveryServiceUrl=self._discovery_url,
+          cache_discovery=False)
+    else:
+      return discovery.build(
+          'tpu', 'v1', credentials=credentials, cache_discovery=False)
+
+  def _fetch_cloud_tpu_metadata(self):
+    """Returns the TPU metadata object from the TPU Get API call."""
+    try:
+      full_name = 'projects/%s/locations/%s/nodes/%s' % (
+          self._project, self._zone, compat.as_text(self._tpu))
+      service = self._tpu_service()
+      r = service.projects().locations().nodes().get(name=full_name)
+      return r.execute()
+    except Exception as e:
+      raise ValueError("Could not lookup TPU metadata from name '%s'. Please "
+                       'doublecheck the tpu argument in the TPUClusterResolver '
+                       'constructor. Exception: %s' % (self._tpu, e))
+
+  def __enter__(self):
+    self._open = True
+
+  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
+    del type, value, traceback
+
+  def recoverable(self):
+    """Returns true if the TPU is in a state where training should eventually resume.
+
+    If false the TPU is in a unrecoverable state and should be recreated.
+    """
+    state = self.state()
+    if state and state in ['TERMINATED', 'PREEMPTED']:
+      return False
+    return True
+
+  def state(self):
+    """Return state of the TPU."""
+    if self._use_api:
+      metadata = self._fetch_cloud_tpu_metadata()
+      if 'state' in metadata:
+        return metadata['state']
+
+    return None
+
+  def api_available(self):
+    """Return if the Cloud TPU API is available, if not certain features will not work."""
+    return self._use_api
+
+  def name(self):
+    """Return the name of the tpu, or the ip address if name is not provided."""
+    return self._tpu
+
+  def get_local_ip(self):
+    """Return the local ip address of the Google Cloud VM the workload is running on."""
+    return _request_compute_metadata('instance/network-interfaces/0/ip')
+
+  def network_endpoints(self):
+    """Return a list of tpu endpoints."""
+    if not self._use_api:
+      return list(_environment_var_to_network_endpoints(self._tpu))
+    response = self._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access
+
+    if 'state' in response and response['state'] != 'READY':
+      raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
+                         (compat.as_text(self._tpu), response['state']))
+    if 'networkEndpoints' in response:
+      return response['networkEndpoints']
+    else:
+      return [{'ipAddress': response['ipAddress'], 'port': response['port']}]
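Aside (not part of the diff): a minimal usage sketch of the new client, assuming the import path used by the accompanying test file. With a grpc:// address the client skips the Cloud TPU API entirely and parses endpoints straight from the string:

# Sketch only, not part of the commit. Import path is an assumption
# taken from the test file below.
from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client

# A grpc:// address means _use_api is False, so no metadata or API calls.
client = cloud_tpu_client.CloudTPUClient(tpu='grpc://10.1.2.3:8470')
print(client.name())               # 'grpc://10.1.2.3:8470'
print(client.api_available())      # False
print(client.network_endpoints())  # [{'ipAddress': '10.1.2.3', 'port': '8470'}]
print(client.recoverable())        # True (no API, so no TERMINATED/PREEMPTED state)

# A bare TPU name would instead resolve project/zone from GCE metadata and
# query the Cloud TPU API via apiclient/oauth2client.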
@@ -0,0 +1,251 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Lint as: python3
+"""Tests for cloud_tpu_client."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
+from tensorflow.python.platform import test
+
+mock = test.mock
+
+
+def mock_request_compute_metadata(path):
+  if path == 'project/project-id':
+    return 'test-project'
+  elif path == 'instance/zone':
+    return 'projects/test-project/locations/us-central1-c'
+  elif path == 'instance/network-interfaces/0/ip':
+    return '10.128.1.2'
+  return ''
+
+
+class MockRequestClass(object):
+
+  def __init__(self, name, tpu_map):
+    self._name = name
+    self._tpu_map = tpu_map
+
+  def execute(self):
+    if self._name in self._tpu_map:
+      return self._tpu_map[self._name]
+    else:
+      raise KeyError('Resource %s was not found' % self._name)
+
+
+class MockNodeClass(object):
+
+  def __init__(self, tpu_map):
+    self._tpu_map = tpu_map
+
+  def get(self, name):
+    return MockRequestClass(name, self._tpu_map)
+
+
+class CloudTpuClientTest(test.TestCase):
+
+  def setUp(self):
+    super(CloudTpuClientTest, self).setUp()
+    if 'TPU_API_DISCOVERY_URL' in os.environ:
+      del os.environ['TPU_API_DISCOVERY_URL']
+    if 'TPU_NAME' in os.environ:
+      del os.environ['TPU_NAME']
+
+  def mock_service_client(self, tpu_map=None):
+    if tpu_map is None:
+      tpu_map = {}
+
+    mock_locations = mock.MagicMock()
+    mock_locations.nodes.return_value = MockNodeClass(tpu_map)
+
+    mock_project = mock.MagicMock()
+    mock_project.locations.return_value = mock_locations
+
+    mock_client = mock.MagicMock()
+    mock_client.projects.return_value = mock_project
+    return mock_client
+
+  def testEnvironmentDiscoveryUrl(self):
+    os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
+    self.assertEqual('https://{api}.internal/{apiVersion}',
+                     (cloud_tpu_client._environment_discovery_url()))
+
+  def testEnvironmentVarToNetworkEndpointsSingleIp(self):
+    self.assertEqual(
+        [{'ipAddress': '1.2.3.4', 'port': '1234'}],
+        list(cloud_tpu_client._environment_var_to_network_endpoints(
+            '1.2.3.4:1234')))
+
+  def testEnvironmentVarToNetworkEndpointsSingleGrpcAddress(self):
+    self.assertEqual(
+        [{'ipAddress': '1.2.3.4', 'port': '2000'}],
+        list(
+            cloud_tpu_client._environment_var_to_network_endpoints(
+                'grpc://1.2.3.4:2000')))
+
+  def testEnvironmentVarToNetworkEndpointsMultipleIps(self):
+    self.assertEqual(
+        [{'ipAddress': '1.2.3.4', 'port': '2000'},
+         {'ipAddress': '5.6.7.8', 'port': '1234'}],
+        list(
+            cloud_tpu_client._environment_var_to_network_endpoints(
+                '1.2.3.4:2000,5.6.7.8:1234')))
+
+  def testEnvironmentVarToNetworkEndpointsMultipleGrpcAddresses(self):
+    self.assertEqual(
+        [{'ipAddress': '1.2.3.4', 'port': '2000'},
+         {'ipAddress': '5.6.7.8', 'port': '1234'}],
+        list(cloud_tpu_client._environment_var_to_network_endpoints(
+            'grpc://1.2.3.4:2000,grpc://5.6.7.8:1234')))
+
+  def testEnvironmentVarToNetworkEndpointsMissingPortAndMixed(self):
+    self.assertEqual(
+        [{'ipAddress': '1.2.3.4', 'port': '2000'},
+         {'ipAddress': '5.6.7.8', 'port': '8470'}],
+        list(cloud_tpu_client._environment_var_to_network_endpoints(
+            '1.2.3.4:2000,grpc://5.6.7.8')))
+
+  def testInitializeNoArguments(self):
+    with self.assertRaisesRegex(
+        ValueError, 'Please provide a TPU Name to connect to.'):
+      cloud_tpu_client.CloudTPUClient()
+
+  def testInitializeMultiElementTpuArray(self):
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        'Using multiple TPUs in a single session is not yet implemented'):
+      cloud_tpu_client.CloudTPUClient(tpu=['multiple', 'elements'])
+
+  def assertClientContains(self, client):
+    self.assertEqual('tpu_name', client._tpu)
+    self.assertEqual(True, client._use_api)
+    self.assertEqual(None, client._credentials)
+    self.assertEqual('test-project', client._project)
+    self.assertEqual('us-central1-c', client._zone)
+    self.assertEqual(None, client._discovery_url)
+    self.assertEqual([{
+        'ipAddress': '10.1.2.3',
+        'port': '8470'
+    }], client.network_endpoints())
+
+  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
+                     mock_request_compute_metadata)
+  def testInitializeNoArgumentsWithEnvironmentVariable(self):
+    os.environ['TPU_NAME'] = 'tpu_name'
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470',
+            'health': 'HEALTHY'
+        }
+    }
+    client = cloud_tpu_client.CloudTPUClient(
+        service=self.mock_service_client(tpu_map=tpu_map))
+    self.assertClientContains(client)
+
+  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
+                     mock_request_compute_metadata)
+  def testInitializeTpuName(self):
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470',
+            'health': 'HEALTHY'
+        }
+    }
+    client = cloud_tpu_client.CloudTPUClient(
+        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
+    self.assertClientContains(client)
+
+  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
+                     mock_request_compute_metadata)
+  def testInitializeIpAddress(self):
+    client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
+    self.assertEqual('grpc://1.2.3.4:8470', client._tpu)
+    self.assertEqual(False, client._use_api)
+    self.assertEqual(None, client._service)
+    self.assertEqual(None, client._credentials)
+    self.assertEqual(None, client._project)
+    self.assertEqual(None, client._zone)
+    self.assertEqual(None, client._discovery_url)
+    self.assertEqual([{
+        'ipAddress': '1.2.3.4',
+        'port': '8470'
+    }], client.network_endpoints())
+
+  def testInitializeWithoutMetadata(self):
+    client = cloud_tpu_client.CloudTPUClient(
+        tpu='tpu_name', project='project', zone='zone')
+    self.assertEqual('tpu_name', client._tpu)
+    self.assertEqual(True, client._use_api)
+    self.assertEqual(None, client._service)
+    self.assertEqual(None, client._credentials)
+    self.assertEqual('project', client._project)
+    self.assertEqual('zone', client._zone)
+    self.assertEqual(None, client._discovery_url)
+
+  def testRecoverableNoApiAccess(self):
+    client = cloud_tpu_client.CloudTPUClient(tpu='grpc://1.2.3.4:8470')
+    self.assertEqual(True, client.recoverable())
+
+  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
+                     mock_request_compute_metadata)
+  def testRecoverableNoState(self):
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470',
+        }
+    }
+    client = cloud_tpu_client.CloudTPUClient(
+        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
+    self.assertEqual(True, client.recoverable())
+
+  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
+                     mock_request_compute_metadata)
+  def testRecoverableReady(self):
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470',
+            'state': 'READY',
+        }
+    }
+    client = cloud_tpu_client.CloudTPUClient(
+        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
+    self.assertEqual(True, client.recoverable())
+
+  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
+                     mock_request_compute_metadata)
+  def testRecoverablePreempted(self):
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470',
+            'state': 'PREEMPTED',
+        }
+    }
+    client = cloud_tpu_client.CloudTPUClient(
+        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
+    self.assertEqual(False, client.recoverable())
+
+
+if __name__ == '__main__':
+  test.main()
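Aside (not part of the diff): the MagicMock wiring in mock_service_client works because it mirrors the call chain the client issues against the discovery-built API object. A rough sketch of that correspondence, with a made-up tpu_map:

# Sketch only: how the mocked service lines up with the client's call path.
tpu_map = {
    'projects/test-project/locations/us-central1-c/nodes/tpu_name': {
        'ipAddress': '10.1.2.3', 'port': '8470', 'health': 'HEALTHY'
    }
}
# Client side (cloud_tpu_client._fetch_cloud_tpu_metadata):
#   service.projects().locations().nodes().get(name=full_name).execute()
# Mock side:
#   mock_client.projects()   -> mock_project
#   mock_project.locations() -> mock_locations
#   mock_locations.nodes()   -> MockNodeClass(tpu_map)
#   .get(name=...)           -> MockRequestClass(name, tpu_map)
#   .execute()               -> tpu_map[name], or KeyError if missing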
@@ -19,61 +19,30 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import os
 import re
 
-from six.moves import urllib
-from six.moves.urllib.error import URLError
-from six.moves.urllib.request import Request
-from six.moves.urllib.request import urlopen
+from tensorflow.python.distribute.cluster_resolver.cloud_tpu_client import CloudTPUClient
 
 from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
 from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
 from tensorflow.python.distribute.cluster_resolver.cluster_resolver import get_accelerator_devices
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import tf_export
 
-_GOOGLE_API_CLIENT_INSTALLED = True
-try:
-  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
-  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
-except ImportError:
-  _GOOGLE_API_CLIENT_INSTALLED = False
 
-_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
-_ENDPOINTS_SEPARATOR = ','
-_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
-_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
+def is_running_in_gce():
+  return True
+
 
 _TPU_DEVICE_REGEX = re.compile(
     r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
 _TPU_CONN_RETRIES = 120
 
-_GCE_METADATA_ENDPOINT = 'http://metadata.google.internal'
-
 DeviceDetails = collections.namedtuple(
     'DeviceDetails', ['device_map', 'total_cores'])
 
-
-def is_running_in_gce():
-  """Checks for GCE presence by attempting to query the metadata service."""
-  try:
-    req = Request(
-        '%s/computeMetadata/v1' % _GCE_METADATA_ENDPOINT,
-        headers={'Metadata-Flavor': 'Google'})
-    resp = urllib.request.urlopen(req, timeout=1)
-    info = resp.info()
-    if 'Metadata-Flavor' in info and info['Metadata-Flavor'] == 'Google':
-      return True
-  except URLError:
-    pass
-  return False
-
 
 @tf_export('distribute.cluster_resolver.TPUClusterResolver')
 class TPUClusterResolver(ClusterResolver):
   """Cluster Resolver for Google Cloud TPUs.
@@ -89,53 +58,6 @@ class TPUClusterResolver(ClusterResolver):
   Google internal
   """
 
-  def _tpu_service(self):
-    """Creates a new Cloud TPU API object.
-
-    This works around an issue where the underlying HTTP connection sometimes
-    times out when the script has been running for too long. Other methods in
-    this object call this method to get a new API object whenever they need
-    to communicate with the Cloud API.
-
-    Returns:
-      A Google Cloud TPU API object.
-    """
-    if self._service:
-      return self._service
-
-    credentials = self._credentials
-    if credentials is None or credentials == 'default':
-      credentials = GoogleCredentials.get_application_default()
-
-    if self._discovery_url:
-      return discovery.build(
-          'tpu',
-          'v1',
-          credentials=credentials,
-          discoveryServiceUrl=self._discovery_url,
-          cache_discovery=False)
-    else:
-      return discovery.build(
-          'tpu', 'v1', credentials=credentials, cache_discovery=False)
-
-  def _request_compute_metadata(self, path):
-    req = Request('%s/computeMetadata/v1/%s' % (_GCE_METADATA_ENDPOINT, path),
-                  headers={'Metadata-Flavor': 'Google'})
-    resp = urlopen(req)
-    return compat.as_bytes(resp.read())
-
-  def _is_local_tpu(self):
-    return (
-        self._tpu == compat.as_bytes('') or
-        self._tpu == compat.as_bytes('local'))
-
-  def _should_resolve(self):
-    if isinstance(self._should_resolve_override, bool):
-      return self._should_resolve_override
-    else:
-      return not (self._tpu.startswith(compat.as_bytes('grpc://')) or
-                  self._is_local_tpu())
-
   @staticmethod
   def _get_device_dict_and_cores(devices):
     """Returns a dict of hosts to cores and total cores given devices names.
@@ -168,25 +90,6 @@ class TPUClusterResolver(ClusterResolver):
           'should never happen. Devices: {}'.format(device_dict))
     return num_cores_per_host_set.pop()
 
-  @staticmethod
-  def _in_gke():
-    """When running in GKE, the environment variable will be set."""
-    return _GKE_ENV_VARIABLE in os.environ
-
-  @staticmethod
-  def _gke_endpoints():
-    return os.environ[_GKE_ENV_VARIABLE]
-
-  @staticmethod
-  def _env_var_fallback():
-    if _DEFAULT_ENV_VARIABLE in os.environ:
-      return os.environ[_DEFAULT_ENV_VARIABLE]
-    return None
-
-  @staticmethod
-  def _environment_discovery_url():
-    return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
-
   def __init__(self,
                tpu=None,
                zone=None,
@@ -204,11 +107,11 @@ class TPUClusterResolver(ClusterResolver):
 
     Args:
       tpu: A string corresponding to the TPU to use. If the string is an empty
-        string, the string 'local', or a string that begins with 'grpc://',
-        then it is assumed to not correspond with a Cloud TPU and will
-        instead be passed as the session master and no ClusterSpec propagation
-        will be done. In the future, this may also support a list of strings
-        when multiple Cloud TPUs are used.
+        string, the string 'local', or a string that begins with 'grpc://', then
+        it is assumed to not correspond with a Cloud TPU and will instead be
+        passed as the session master and no ClusterSpec propagation will be
+        done. In the future, this may also support a list of strings when
+        multiple Cloud TPUs are used.
       zone: Zone where the TPUs are located. If omitted or empty, we will assume
         that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
@@ -232,6 +135,7 @@ class TPUClusterResolver(ClusterResolver):
         filled in produce an absolute URL to the discovery document for that
         service. The environment variable 'TPU_API_DISCOVERY_URL' will override
         this.
+      **kwargs: Extra keyword arguments passed to CloudTPUClient.
 
     Raises:
       ImportError: If the googleapiclient is not installed.
@@ -239,92 +143,32 @@ class TPUClusterResolver(ClusterResolver):
       RuntimeError: If an empty TPU name is specified and this is running in a
         Google Cloud environment.
     """
-    if isinstance(tpu, list):
-      if not tpu:
-        raise ValueError('At least one TPU must be specified.')
-      if len(tpu) != 1:
-        raise NotImplementedError(
-            'Using multiple TPUs in a single session is not yet implemented')
-      tpu = tpu[0]
-
-    in_gke = self._in_gke()
-    # When using GKE with Cloud TPUs, the env variable will be set.
-    if tpu is None:
-      if in_gke:
-        tpu = self._gke_endpoints()
-      else:
-        tpu = self._env_var_fallback()
-
-    if tpu is None:
-      raise ValueError('Please provide a TPU Name to connect to.')
-
-    self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes
-
-    # If we are running in Cloud and don't specify a TPU name
-    if is_running_in_gce() and not self._tpu:
-      raise RuntimeError('You need to specify a TPU Name if you are running in '
-                         'the Google Cloud environment.')
-
+    self._cloud_tpu_client = CloudTPUClient(
+        tpu=tpu,
+        zone=zone,
+        project=project,
+        credentials=credentials,
+        service=service,
+        discovery_url=discovery_url)
+
+    self._tpu = self._cloud_tpu_client.name()
     # By default the task_type is 'worker` and the task_id is 0 (which is the
     # first worker in the task).
     self.task_type = job_name
     self.task_id = 0
 
-    if self._is_local_tpu():
-      self.rpc_layer = None
-    else:
-      self._environment = ''
-      self.rpc_layer = 'grpc'
-
-    # Setting this overrides the return value of self._should_resolve()
-    self._should_resolve_override = None
-
-    # We strip out the protocol if it is included, and override the
-    # shouldResolve function to never resolve. We are adding the protocol back
-    # in later in self.master().
-    if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'):
-      tpu = tpu[len(self.rpc_layer + '://'):]
-      self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes
-      self._should_resolve_override = False
-
-    # Whether we should actually attempt to contact Cloud APIs
-    should_resolve = self._should_resolve()
-
-    # We error out if we are in a non-Cloud environment which cannot talk to the
-    # Cloud APIs using the standard class and a special object is not passed in.
-    self._service = service
-    if (self._service is None and should_resolve and
-        not _GOOGLE_API_CLIENT_INSTALLED):
-      raise ImportError('googleapiclient and oauth2client must be installed '
-                        'before using the TPU cluster resolver. Execute: '
-                        '`pip install --upgrade google-api-python-client '
-                        'oauth2client` to install with pip.')
-
-    # We save user-passed credentials, unless the user didn't pass in anything.
-    self._credentials = credentials
-    if (credentials == 'default' and should_resolve and
-        _GOOGLE_API_CLIENT_INSTALLED):
-      self._credentials = None
-
-    # Automatically detect project and zone if unspecified.
-    if not project and should_resolve:
-      project = compat.as_str(
-          self._request_compute_metadata('project/project-id'))
-    if not zone and should_resolve:
-      zone_path = compat.as_str(self._request_compute_metadata('instance/zone'))
-      zone = zone_path.split('/')[-1]
-    self._project = project
-    self._zone = zone
-
-    self._discovery_url = self._environment_discovery_url() or discovery_url
-
     self._coordinator_name = coordinator_name
-    if (coordinator_name and not coordinator_address and
-        (should_resolve or in_gke)):
+    if (coordinator_name and not coordinator_address):
       self._start_local_server()
     else:
       self._coordinator_address = coordinator_address
 
+  def __enter__(self):
+    self._cloud_tpu_client.enter()
+
+  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
+    self._cloud_tpu_client.exit(type, value, traceback)
+
   def master(self, task_type=None, task_id=None, rpc_layer=None):
     """Get the Master string to be used for the session.
 
@@ -350,35 +194,27 @@ class TPUClusterResolver(ClusterResolver):
     Raises:
       ValueError: If none of the TPUs specified exists.
     """
-    if self._should_resolve():
-      # We are going to communicate with the Cloud TPU APIs to get a Cluster.
-      cluster_spec = self.cluster_spec()
-      if task_type is not None and task_id is not None:
-        # task_type and task_id is from the function parameter
-        master = cluster_spec.task_address(task_type, task_id)
-      elif self.task_type is not None and self.task_id is not None:
-        # task_type and task_id is from the object
-        master = cluster_spec.task_address(self.task_type, self.task_id)
-      else:
-        # by default we take the first item in the cluster with the right name
-        job_tasks = cluster_spec.job_tasks(self.task_type)
-        if not job_tasks:
-          raise ValueError('No TPUs with the specified names exist.')
-        master = job_tasks[0]
-    else:
-      if isinstance(self._tpu, (bytes, bytearray)):
-        master = compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR)[0]
-      else:
-        master = self._tpu.split(_ENDPOINTS_SEPARATOR)[0]
-    return format_master_url(master, rpc_layer or self.rpc_layer)
+    cluster_spec = self.cluster_spec()
+    if task_type is not None and task_id is not None:
+      # task_type and task_id is from the function parameter
+      master = cluster_spec.task_address(task_type, task_id)
+    elif self.task_type is not None and self.task_id is not None:
+      # task_type and task_id is from the object
+      master = cluster_spec.task_address(self.task_type, self.task_id)
+    else:
+      # by default we take the first item in the cluster with the right name
+      job_tasks = cluster_spec.job_tasks(self.task_type)
+      if not job_tasks:
+        raise ValueError('No TPUs with the specified names exist.')
+      master = job_tasks[0]
+    return format_master_url(master, 'grpc')
 
   def get_master(self):
     return self.master()
 
   def get_job_name(self):
-    if ops.executing_eagerly_outside_functions() or self._should_resolve(
-    ) or is_running_in_gce():
-      return self.task_type
+    return self.task_type
 
   def cluster_spec(self):
     """Returns a ClusterSpec object based on the latest TPU information.
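Aside (not part of the diff): master() now always resolves through the ClusterSpec. A rough sketch of that lookup using tf.train's ClusterSpec directly; format_master_url is an internal helper, so the 'grpc://' prefixing is shown by hand here:

# Sketch only. Mirrors the new master() logic with a hand-built ClusterSpec.
from tensorflow.python.training import server_lib

cluster_spec = server_lib.ClusterSpec(
    {'worker': ['10.1.2.3:8470', '10.1.2.4:8470']})
master = cluster_spec.task_address('worker', 0)  # '10.1.2.3:8470'
# format_master_url(master, 'grpc') would yield 'grpc://10.1.2.3:8470'.
print('grpc://' + master)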
@@ -402,65 +238,20 @@ class TPUClusterResolver(ClusterResolver):
     # tasks and
     #   a. Create a ClusterSpec with the coordinator
     #   b. Create a ClusterSpec without the coordinator
-    # 3. [Other (legacy non-gRPC).] We should return None.
     ############################################################################
 
-    if self._should_resolve():
-      # Case 1.
-      response = self._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access
-
-      if 'state' in response and response['state'] != 'READY':
-        raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
-                           (compat.as_text(self._tpu), response['state']))
-
-      if 'networkEndpoints' in response:
-        worker_list = [
-            '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
-            for endpoint in response['networkEndpoints']
-        ]
-      else:
-        # Fall back to the deprecated response format
-        instance_url = '%s:%s' % (response['ipAddress'], response['port'])
-        worker_list = [instance_url]
-
-      cluster_spec = {self.task_type: worker_list}
-    else:
-      is_eager = ops.executing_eagerly_outside_functions()
-      if self.rpc_layer is None and not is_eager:
-        # Case 3.
-        return None
-      # Case 2.
-      tpus = []
-      for tpu in compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR):
-        # We are working around the fact that GKE environment variable that is
-        # supplied to us has the protocol string embedded in it, but we want
-        # to strip it out for the ClusterSpec.
-        if (self.rpc_layer is not None and
-            tpu.startswith(self.rpc_layer + '://')):
-          tpus.append(tpu[len(self.rpc_layer + '://'):])
-        else:
-          tpus.append(tpu)
-      cluster_spec = {self.task_type: tpus}
+    network_endpoints = self._cloud_tpu_client.network_endpoints()
+    worker_list = [
+        '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
+        for endpoint in network_endpoints
+    ]
+    cluster_spec = {self.task_type: worker_list}
 
     if self._coordinator_address:
       # {1, 2}.a
       cluster_spec[self._coordinator_name] = [self._coordinator_address]
 
     return server_lib.ClusterSpec(cluster_spec)
 
-  def _fetch_cloud_tpu_metadata(self):
-    """Returns the TPU metadata object from the TPU Get API call."""
-    try:
-      full_name = 'projects/%s/locations/%s/nodes/%s' % (
-          self._project, self._zone, compat.as_text(self._tpu))
-      service = self._tpu_service()
-      request = service.projects().locations().nodes().get(name=full_name)
-      return request.execute()
-    except Exception as e:
-      raise ValueError("Could not lookup TPU metadata from name '%s'. Please "
-                       "doublecheck the tpu argument in the TPUClusterResolver "
-                       "constructor. Exception: %s" % (self._tpu, e))
-
   def num_accelerators(self,
                        task_type=None,
                        task_id=None,
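Aside (not part of the diff): a small sketch of the endpoint-to-ClusterSpec conversion that cluster_spec() now performs, with made-up endpoint data in the shape CloudTPUClient.network_endpoints() returns:

# Sketch only, with made-up endpoints.
from tensorflow.python.training import server_lib

network_endpoints = [{'ipAddress': '10.2.3.4', 'port': 8470},
                     {'ipAddress': '10.2.3.5', 'port': 8470}]
worker_list = ['%s:%s' % (ep['ipAddress'], ep['port'])
               for ep in network_endpoints]
spec = server_lib.ClusterSpec({'worker': worker_list})
print(spec.job_tasks('worker'))  # ['10.2.3.4:8470', '10.2.3.5:8470']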
@@ -510,12 +301,11 @@ class TPUClusterResolver(ClusterResolver):
     return self._environment
 
   def _start_local_server(self):
-    address = compat.as_text(
-        self._request_compute_metadata('instance/network-interfaces/0/ip'))
-    self._server = server_lib.Server(
-        {
-            'local': ['0.0.0.0:0']
-        }, protocol='grpc', config=None, start=True)
+    address = compat.as_text(self._cloud_tpu_client.get_local_ip())
+    self._server = server_lib.Server({'local': ['0.0.0.0:0']},
+                                     protocol='grpc',
+                                     config=None,
+                                     start=True)
     # self._server.target is of the form: grpc://ipaddress:port
     target = compat.as_bytes(self._server.target)
     splits = target.split(compat.as_bytes(':'))
@@ -25,6 +25,7 @@ from six.moves.urllib.error import URLError
 
 from tensorflow.python import framework
 from tensorflow.python.client import session
+from tensorflow.python.distribute.cluster_resolver import cloud_tpu_client
 from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
 from tensorflow.python.eager.context import LogicalDevice
 from tensorflow.python.framework import errors
@@ -32,7 +33,6 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import compat
 
 mock = test.mock
-
 
@@ -58,8 +58,8 @@ class MockNodeClass(object):
     return MockRequestClass(name, self._tpu_map)
 
 
-def mock_request_compute_metadata(cls, *args, **kwargs):
-  del cls, kwargs  # Unused.
+def mock_request_compute_metadata(*args, **kwargs):
+  del kwargs  # Unused.
   if args[0] == 'project/project-id':
     return 'test-project'
   elif args[0] == 'instance/zone':
@@ -107,12 +107,12 @@ class TPUClusterResolverTest(test.TestCase):
     self.assertProtoEquals(
         expected_proto,
         server_lib.ClusterSpec(cluster_spec).as_cluster_def())
-    self.assertProtoEquals(expected_proto,
-                           server_lib.ClusterSpec(
-                               cluster_spec.as_cluster_def()).as_cluster_def())
-    self.assertProtoEquals(expected_proto,
-                           server_lib.ClusterSpec(
-                               cluster_spec.as_dict()).as_cluster_def())
+    self.assertProtoEquals(
+        expected_proto,
+        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
+    self.assertProtoEquals(
+        expected_proto,
+        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
 
   def mock_service_client(self, tpu_map=None):
 
@@ -130,26 +130,19 @@ class TPUClusterResolverTest(test.TestCase):
 
     return mock_client
 
-  @mock.patch.object(resolver, 'is_running_in_gce',
-                     mock_is_running_in_gce)
+  @mock.patch.object(resolver, 'is_running_in_gce', mock_is_running_in_gce)
   def testCheckRunningInGceWithNoTpuName(self):
-    with self.assertRaisesRegexp(RuntimeError, '.*Google Cloud.*'):
+    with self.assertRaisesRegexp(ValueError,
+                                 'Please provide a TPU Name to connect to.*'):
       resolver.TPUClusterResolver(tpu='')
 
-  @mock.patch.object(six.moves.urllib.request,
-                     'urlopen',
+  @mock.patch.object(six.moves.urllib.request, 'urlopen',
                      mock_running_in_gce_urlopen)
   def testIsRunningInGce(self):
     self.assertTrue(resolver.is_running_in_gce())
 
-  @mock.patch.object(six.moves.urllib.request,
-                     'urlopen',
-                     mock_not_running_in_gce_urlopen)
-  def testIsNotRunningInGce(self):
-    self.assertFalse(resolver.is_running_in_gce())
-
-  @mock.patch.object(resolver.TPUClusterResolver,
-                     '_request_compute_metadata', mock_request_compute_metadata)
+  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
+                     mock_request_compute_metadata)
   def testRetrieveProjectAndZoneFromMetadata(self):
     tpu_map = {
         'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
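Aside (not part of the diff): the decorators now patch the module-level metadata helper in cloud_tpu_client rather than a resolver method. A tiny, self-contained illustration of that mock.patch.object pattern on an unrelated module:

# Sketch only: patching a module-level function the way the tests above do.
import math
from unittest import mock

def mock_sqrt(x):
  return 42.0

with mock.patch.object(math, 'sqrt', mock_sqrt):
  print(math.sqrt(9))  # 42.0, the patched function is used inside the block
print(math.sqrt(9))    # 3.0 again once the patch is undone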
@@ -181,8 +174,8 @@ class TPUClusterResolverTest(test.TestCase):
     self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
     self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
 
-  @mock.patch.object(resolver.TPUClusterResolver,
-                     '_request_compute_metadata', mock_request_compute_metadata)
+  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
+                     mock_request_compute_metadata)
   def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self):
     tpu_map = {
         'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
@@ -207,8 +200,8 @@ class TPUClusterResolverTest(test.TestCase):
     self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
     self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
 
-  @mock.patch.object(resolver.TPUClusterResolver,
-                     '_request_compute_metadata', mock_request_compute_metadata)
+  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
+                     mock_request_compute_metadata)
   def testNotReadyCloudTpu(self):
     tpu_map = {
         'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
@@ -306,8 +299,8 @@ class TPUClusterResolverTest(test.TestCase):
     self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
     self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master())
 
-  @mock.patch.object(resolver.TPUClusterResolver,
-                     '_request_compute_metadata', mock_request_compute_metadata)
+  @mock.patch.object(cloud_tpu_client, '_request_compute_metadata',
+                     mock_request_compute_metadata)
   def testPodResolution(self):
     tpu_map = {
         'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
@@ -425,17 +418,10 @@ class TPUClusterResolverTest(test.TestCase):
         coordinator_name=None,
         credentials=None,
         service=self.mock_service_client(tpu_map={}))
-    self.assertEqual(should_resolve, cluster_resolver._should_resolve(),
+    self.assertEqual(should_resolve,
+                     cluster_resolver._cloud_tpu_client.api_available(),
                      "TPU: '%s'" % tpu)
 
-  @mock.patch.object(resolver, 'is_running_in_gce',
-                     mock_is_not_running_in_gce)
-  def testShouldResolveNoName(self):
-    self.verifyShouldResolve('', False)
-
-  def testShouldResolveLocal(self):
-    self.verifyShouldResolve('local', False)
-
   def testShouldResolveGrpc(self):
     self.verifyShouldResolve('grpc://10.1.2.3:8470', False)
 
@@ -461,11 +447,6 @@ class TPUClusterResolverTest(test.TestCase):
     os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
 
     self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
-    self.assertTrue(resolver.TPUClusterResolver._in_gke())
-    self.assertEqual(
-        compat.as_bytes('grpc://10.120.27.5:8470'),
-        compat.as_bytes(
-            resolver.TPUClusterResolver._gke_endpoints()))
 
     cluster_resolver = resolver.TPUClusterResolver()
     self.assertEqual(
@@ -489,15 +470,6 @@ class TPUClusterResolverTest(test.TestCase):
                      'grpc://10.120.27.8:8470')
 
     self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
-    self.assertTrue(resolver.TPUClusterResolver._in_gke())
-    self.assertEqual(
-        compat.as_bytes('grpc://10.120.27.5:8470,'
-                        'grpc://10.120.27.6:8470,'
-                        'grpc://10.120.27.7:8470,'
-                        'grpc://10.120.27.8:8470'),
-        compat.as_bytes(
-            resolver.TPUClusterResolver._gke_endpoints()))
-
     cluster_resolver = resolver.TPUClusterResolver()
     self.assertEqual(
         compat.as_bytes('grpc://10.120.27.5:8470'),
@@ -516,17 +488,9 @@ class TPUClusterResolverTest(test.TestCase):
 
     del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
 
-  def testEnvironmentDiscoveryUrl(self):
-    os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
-    self.assertEqual(
-        'https://{api}.internal/{apiVersion}',
-        (resolver.TPUClusterResolver._environment_discovery_url()))
-
-  def testEnvironmentAndRpcDetectionForGrpcString(self):
+  def testRpcDetectionForGrpcString(self):
     cluster_resolver = resolver.TPUClusterResolver(
         tpu='grpc://10.1.2.3:8470')
-    self.assertEqual(cluster_resolver.environment, '')
-    self.assertEqual(cluster_resolver.rpc_layer, 'grpc')
     self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
 
   def testOverrideTaskTypeAndIndexAndGetMaster(self):
@@ -569,13 +533,8 @@ class TPUClusterResolverTest(test.TestCase):
     cluster_resolver.task_id = 3
     self.assertEqual(cluster_resolver.master(), 'grpc://10.2.3.7:8470')
 
-    self.assertEqual(
-        cluster_resolver.master(
-            task_type='worker', task_id=2, rpc_layer='test'),
-        'test://10.2.3.6:8470')
-
   def testGetDeviceDictAndCoresWithTPUs(self):
-    device_names = [
+    devices = [
         '/job:tpu_worker/task:0/device:TPU:0',
         '/job:tpu_worker/task:1/device:TPU:1',
         '/job:tpu_worker/task:2/device:TPU:0',
@@ -586,8 +545,7 @@ class TPUClusterResolverTest(test.TestCase):
         '/job:tpu_worker/task:3/device:TPU:5',
     ]
     device_list = [
-        session._DeviceAttributes(
-            name, 'TPU', 1024, 0) for name in device_names
+        session._DeviceAttributes(name, 'TPU', 1024, 0) for name in devices
     ]
 
     device_details = resolver.TPUClusterResolver._get_device_dict_and_cores(
@@ -600,7 +558,7 @@ class TPUClusterResolverTest(test.TestCase):
                          '3': ['1', '5']})
 
   def testGetDeviceDictAndCoresWithCPUsAndGPUs(self):
-    device_names = [
+    devices = [
         '/job:tpu_worker/task:0/device:CPU:0',
         '/job:tpu_worker/task:1/device:CPU:0',
         '/job:tpu_worker/task:2/device:CPU:0',
@@ -611,8 +569,7 @@ class TPUClusterResolverTest(test.TestCase):
         '/job:tpu_worker/task:3/device:GPU:1',
     ]
     device_list = [
-        session._DeviceAttributes(
-            name, 'XLA', 1024, 0) for name in device_names
+        session._DeviceAttributes(name, 'XLA', 1024, 0) for name in devices
     ]
 
     device_dict, num_cores =\
@@ -639,8 +596,7 @@ class TPUClusterResolverTest(test.TestCase):
 
   @mock.patch.object(framework.config, 'list_logical_devices')
   @mock.patch.object(session.BaseSession, 'list_devices')
-  @mock.patch.object(resolver, 'is_running_in_gce',
-                     mock_is_not_running_in_gce)
+  @mock.patch.object(resolver, 'is_running_in_gce', mock_is_not_running_in_gce)
   def testNumAcceleratorsSuccess(self, mock_list_devices,
                                  mock_eager_list_devices):
     devices = [
@@ -660,16 +616,73 @@ class TPUClusterResolverTest(test.TestCase):
     mock_eager_list_devices.return_value = devices
     mock_list_devices.return_value = device_list
 
-    cluster_resolver = resolver.TPUClusterResolver(tpu='')
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+            'health':
+                'HEALTHY',
+            'networkEndpoints': [
+                {
+                    'ipAddress': '10.2.3.4',
+                    'port': 8470,
+                },
+                {
+                    'ipAddress': '10.2.3.5',
+                    'port': 8470,
+                },
+                {
+                    'ipAddress': '10.2.3.6',
+                    'port': 8470,
+                },
+                {
+                    'ipAddress': '10.2.3.7',
+                    'port': 8470,
+                },
+            ]
+        }
+    }
+
+    cluster_resolver = resolver.TPUClusterResolver(
+        project='test-project',
+        zone='us-central1-c',
+        tpu='test-tpu-1',
+        service=self.mock_service_client(tpu_map=tpu_map))
     self.assertEqual(cluster_resolver.num_accelerators(), {'TPU': 2})
 
   @mock.patch.object(framework.config, 'list_logical_devices')
   @mock.patch.object(session.BaseSession, 'list_devices')
-  @mock.patch.object(resolver, 'is_running_in_gce',
-                     mock_is_not_running_in_gce)
+  @mock.patch.object(resolver, 'is_running_in_gce', mock_is_not_running_in_gce)
   def testNumAcceleratorsRetryFailure(self, mock_list_devices,
                                       mock_eager_list_devices):
-    cluster_resolver = resolver.TPUClusterResolver(tpu='')
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+            'health':
+                'HEALTHY',
+            'networkEndpoints': [
+                {
+                    'ipAddress': '10.2.3.4',
+                    'port': 8470,
+                },
+                {
+                    'ipAddress': '10.2.3.5',
+                    'port': 8470,
+                },
+                {
+                    'ipAddress': '10.2.3.6',
+                    'port': 8470,
+                },
+                {
+                    'ipAddress': '10.2.3.7',
+                    'port': 8470,
+                },
+            ]
+        }
+    }
+
+    cluster_resolver = resolver.TPUClusterResolver(
+        project='test-project',
+        zone='us-central1-c',
+        tpu='test-tpu-1',
+        service=self.mock_service_client(tpu_map=tpu_map))
     mock_list_devices.side_effect = errors.DeadlineExceededError(
         None, None, 'timeout')
     mock_eager_list_devices.side_effect = errors.DeadlineExceededError(
@@ -84,20 +84,10 @@ class _TPUPollingThread(threading.Thread):
       return
 
    while self._running:
-      response = self._cluster._fetch_cloud_tpu_metadata()  # pylint: disable=protected-access
-      logging.warning(
-          'TPUPollingThread found TPU %s in state %s, and health %s.',
-          self._cluster._tpu, response['state'],  # pylint: disable=protected-access
-          response.get('health', 'UNKNOWN'))
-
-      if 'state' in response and response['state'] in [
-          'TERMINATED', 'PREEMPTED'
-      ]:
+      recoverable = self._cluster._cloud_tpu_client.recoverable()  # pylint: disable=protected-access
+      if not recoverable:
        logging.warning(
-            'TPU node %s reached an unrecoverable state %s, '
-            'terminating training.',
-            self._cluster._tpu,  # pylint: disable=protected-access
-            response['state'])
+            'TPUPollingThread found TPU %s in state %s',
+            self._cluster._tpu, self._cluster._cloud_tpu_client.state())  # pylint: disable=protected-access
         os._exit(1)  # pylint: disable=protected-access
 
       time.sleep(self._interval)
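Aside (not part of the diff): the polling loop now reduces to a single recoverable() check per interval. A standalone sketch of that decision with a stand-in client (not the real CloudTPUClient):

# Sketch only: the decision the refactored polling loop makes each interval.
class _FakeClient(object):

  def __init__(self, state):
    self._state = state

  def state(self):
    return self._state

  def recoverable(self):
    # Mirrors CloudTPUClient: only TERMINATED/PREEMPTED are unrecoverable.
    return self._state not in ['TERMINATED', 'PREEMPTED']

for s in ['READY', 'PREEMPTED', None]:
  client = _FakeClient(s)
  print(s, 'terminate training' if not client.recoverable() else 'keep polling')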