Fixes an issue in TPUClusterResolver where an underlying HTTP connection times out, resulting in user-visible errors.
PiperOrigin-RevId: 221031372
This commit is contained in:
parent
19f97dc4bd
commit
89c0134fd8
@ -50,6 +50,34 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
Cloud Platform project.
|
Cloud Platform project.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _tpuService(self):
|
||||||
|
"""Creates a new Cloud TPU API object.
|
||||||
|
|
||||||
|
This works around an issue where the underlying HTTP connection sometimes
|
||||||
|
times out when the script has been running for too long. Other methods in
|
||||||
|
this object calls this method to get a new API object whenever they need
|
||||||
|
to communicate with the Cloud API.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Google Cloud TPU API object.
|
||||||
|
"""
|
||||||
|
if self._service:
|
||||||
|
return self._service
|
||||||
|
|
||||||
|
credentials = self.credentials
|
||||||
|
if credentials is None or credentials == 'default':
|
||||||
|
credentials = GoogleCredentials.get_application_default()
|
||||||
|
|
||||||
|
if self._discovery_url:
|
||||||
|
return discovery.build(
|
||||||
|
'tpu', 'v1alpha1',
|
||||||
|
credentials=credentials,
|
||||||
|
discoveryServiceUrl=self._discovery_url)
|
||||||
|
else:
|
||||||
|
return discovery.build(
|
||||||
|
'tpu', 'v1alpha1',
|
||||||
|
credentials=credentials)
|
||||||
|
|
||||||
def _requestComputeMetadata(self, path):
|
def _requestComputeMetadata(self, path):
|
||||||
req = Request('http://metadata/computeMetadata/v1/%s' % path,
|
req = Request('http://metadata/computeMetadata/v1/%s' % path,
|
||||||
headers={'Metadata-Flavor': 'Google'})
|
headers={'Metadata-Flavor': 'Google'})
|
||||||
@ -81,7 +109,7 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _discoveryUrl():
|
def _environmentDiscoveryUrl():
|
||||||
return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
|
return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -154,49 +182,42 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
|
|
||||||
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
|
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
|
||||||
self._job_name = job_name
|
self._job_name = job_name
|
||||||
self._credentials = credentials
|
|
||||||
|
|
||||||
|
# Whether we should actually attempt to contact Cloud APIs
|
||||||
should_resolve = self._shouldResolve()
|
should_resolve = self._shouldResolve()
|
||||||
|
|
||||||
|
# We error out if we are in a non-Cloud environment which cannot talk to the
|
||||||
|
# Cloud APIs using the standard class and a special object is not passed in.
|
||||||
|
self._service = service
|
||||||
|
if (self._service is None and should_resolve and
|
||||||
|
not _GOOGLE_API_CLIENT_INSTALLED):
|
||||||
|
raise ImportError('googleapiclient and oauth2client must be installed '
|
||||||
|
'before using the TPU cluster resolver. Execute: '
|
||||||
|
'`pip install --upgrade google-api-python-client` '
|
||||||
|
'and `pip install --upgrade oauth2client` to '
|
||||||
|
'install with pip.')
|
||||||
|
|
||||||
|
# We save user-passed credentials, unless the user didn't pass in anything.
|
||||||
|
self._credentials = credentials
|
||||||
|
if (credentials == 'default' and should_resolve and
|
||||||
|
_GOOGLE_API_CLIENT_INSTALLED):
|
||||||
|
self._credentials = None
|
||||||
|
|
||||||
|
# Automatically detect project and zone if unspecified.
|
||||||
if not project and should_resolve:
|
if not project and should_resolve:
|
||||||
project = compat.as_str(
|
project = compat.as_str(
|
||||||
self._requestComputeMetadata('project/project-id'))
|
self._requestComputeMetadata('project/project-id'))
|
||||||
|
|
||||||
if not zone and should_resolve:
|
if not zone and should_resolve:
|
||||||
zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
|
zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
|
||||||
zone = zone_path.split('/')[-1]
|
zone = zone_path.split('/')[-1]
|
||||||
|
|
||||||
self._project = project
|
self._project = project
|
||||||
self._zone = zone
|
self._zone = zone
|
||||||
|
|
||||||
if credentials == 'default' and should_resolve:
|
self._discovery_url = self._environmentDiscoveryUrl() or discovery_url
|
||||||
if _GOOGLE_API_CLIENT_INSTALLED:
|
|
||||||
self._credentials = GoogleCredentials.get_application_default()
|
|
||||||
|
|
||||||
if service is None and should_resolve:
|
|
||||||
if not _GOOGLE_API_CLIENT_INSTALLED:
|
|
||||||
raise ImportError('googleapiclient and oauth2client must be installed '
|
|
||||||
'before using the TPU cluster resolver. Execute: '
|
|
||||||
'`pip install --upgrade google-api-python-client` '
|
|
||||||
'and `pip install --upgrade oauth2client` to '
|
|
||||||
'install with pip.')
|
|
||||||
|
|
||||||
final_discovery_url = self._discoveryUrl() or discovery_url
|
|
||||||
if final_discovery_url:
|
|
||||||
self._service = discovery.build(
|
|
||||||
'tpu', 'v1alpha1',
|
|
||||||
credentials=self._credentials,
|
|
||||||
discoveryServiceUrl=final_discovery_url)
|
|
||||||
else:
|
|
||||||
self._service = discovery.build(
|
|
||||||
'tpu', 'v1alpha1',
|
|
||||||
credentials=self._credentials)
|
|
||||||
else:
|
|
||||||
self._service = service
|
|
||||||
|
|
||||||
self._coordinator_name = coordinator_name
|
self._coordinator_name = coordinator_name
|
||||||
if coordinator_name and not coordinator_address and (should_resolve or
|
if (coordinator_name and not coordinator_address and
|
||||||
in_gke):
|
(should_resolve or in_gke)):
|
||||||
self._start_local_server()
|
self._start_local_server()
|
||||||
else:
|
else:
|
||||||
self._coordinator_address = coordinator_address
|
self._coordinator_address = coordinator_address
|
||||||
@ -270,7 +291,8 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
# Case 1.
|
# Case 1.
|
||||||
full_name = 'projects/%s/locations/%s/nodes/%s' % (
|
full_name = 'projects/%s/locations/%s/nodes/%s' % (
|
||||||
self._project, self._zone, compat.as_text(self._tpu))
|
self._project, self._zone, compat.as_text(self._tpu))
|
||||||
request = self._service.projects().locations().nodes().get(name=full_name)
|
service = self._tpuService()
|
||||||
|
request = service.projects().locations().nodes().get(name=full_name)
|
||||||
response = request.execute()
|
response = request.execute()
|
||||||
|
|
||||||
if 'state' in response and response['state'] != 'READY':
|
if 'state' in response and response['state'] != 'READY':
|
||||||
|
@ -459,10 +459,10 @@ class TPUClusterResolverTest(test.TestCase):
|
|||||||
|
|
||||||
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
|
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
|
||||||
|
|
||||||
def testDiscoveryUrl(self):
|
def testEnvironmentDiscoveryUrl(self):
|
||||||
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
|
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
|
||||||
self.assertEqual('https://{api}.internal/{apiVersion}',
|
self.assertEqual('https://{api}.internal/{apiVersion}',
|
||||||
TPUClusterResolver._discoveryUrl())
|
TPUClusterResolver._environmentDiscoveryUrl())
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user