Fixes an issue in TPUClusterResolver where an underlying HTTP connection times out, resulting in user visible errors
PiperOrigin-RevId: 221031372
This commit is contained in:
parent
19f97dc4bd
commit
89c0134fd8
@ -50,6 +50,34 @@ class TPUClusterResolver(ClusterResolver):
|
||||
Cloud Platform project.
|
||||
"""
|
||||
|
||||
def _tpuService(self):
|
||||
"""Creates a new Cloud TPU API object.
|
||||
|
||||
This works around an issue where the underlying HTTP connection sometimes
|
||||
times out when the script has been running for too long. Other methods in
|
||||
this object calls this method to get a new API object whenever they need
|
||||
to communicate with the Cloud API.
|
||||
|
||||
Returns:
|
||||
A Google Cloud TPU API object.
|
||||
"""
|
||||
if self._service:
|
||||
return self._service
|
||||
|
||||
credentials = self.credentials
|
||||
if credentials is None or credentials == 'default':
|
||||
credentials = GoogleCredentials.get_application_default()
|
||||
|
||||
if self._discovery_url:
|
||||
return discovery.build(
|
||||
'tpu', 'v1alpha1',
|
||||
credentials=credentials,
|
||||
discoveryServiceUrl=self._discovery_url)
|
||||
else:
|
||||
return discovery.build(
|
||||
'tpu', 'v1alpha1',
|
||||
credentials=credentials)
|
||||
|
||||
def _requestComputeMetadata(self, path):
|
||||
req = Request('http://metadata/computeMetadata/v1/%s' % path,
|
||||
headers={'Metadata-Flavor': 'Google'})
|
||||
@ -81,7 +109,7 @@ class TPUClusterResolver(ClusterResolver):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _discoveryUrl():
|
||||
def _environmentDiscoveryUrl():
|
||||
return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
|
||||
|
||||
def __init__(self,
|
||||
@ -154,49 +182,42 @@ class TPUClusterResolver(ClusterResolver):
|
||||
|
||||
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
|
||||
self._job_name = job_name
|
||||
self._credentials = credentials
|
||||
|
||||
# Whether we should actually attempt to contact Cloud APIs
|
||||
should_resolve = self._shouldResolve()
|
||||
|
||||
# We error out if we are in a non-Cloud environment which cannot talk to the
|
||||
# Cloud APIs using the standard class and a special object is not passed in.
|
||||
self._service = service
|
||||
if (self._service is None and should_resolve and
|
||||
not _GOOGLE_API_CLIENT_INSTALLED):
|
||||
raise ImportError('googleapiclient and oauth2client must be installed '
|
||||
'before using the TPU cluster resolver. Execute: '
|
||||
'`pip install --upgrade google-api-python-client` '
|
||||
'and `pip install --upgrade oauth2client` to '
|
||||
'install with pip.')
|
||||
|
||||
# We save user-passed credentials, unless the user didn't pass in anything.
|
||||
self._credentials = credentials
|
||||
if (credentials == 'default' and should_resolve and
|
||||
_GOOGLE_API_CLIENT_INSTALLED):
|
||||
self._credentials = None
|
||||
|
||||
# Automatically detect project and zone if unspecified.
|
||||
if not project and should_resolve:
|
||||
project = compat.as_str(
|
||||
self._requestComputeMetadata('project/project-id'))
|
||||
|
||||
if not zone and should_resolve:
|
||||
zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
|
||||
zone = zone_path.split('/')[-1]
|
||||
|
||||
self._project = project
|
||||
self._zone = zone
|
||||
|
||||
if credentials == 'default' and should_resolve:
|
||||
if _GOOGLE_API_CLIENT_INSTALLED:
|
||||
self._credentials = GoogleCredentials.get_application_default()
|
||||
|
||||
if service is None and should_resolve:
|
||||
if not _GOOGLE_API_CLIENT_INSTALLED:
|
||||
raise ImportError('googleapiclient and oauth2client must be installed '
|
||||
'before using the TPU cluster resolver. Execute: '
|
||||
'`pip install --upgrade google-api-python-client` '
|
||||
'and `pip install --upgrade oauth2client` to '
|
||||
'install with pip.')
|
||||
|
||||
final_discovery_url = self._discoveryUrl() or discovery_url
|
||||
if final_discovery_url:
|
||||
self._service = discovery.build(
|
||||
'tpu', 'v1alpha1',
|
||||
credentials=self._credentials,
|
||||
discoveryServiceUrl=final_discovery_url)
|
||||
else:
|
||||
self._service = discovery.build(
|
||||
'tpu', 'v1alpha1',
|
||||
credentials=self._credentials)
|
||||
else:
|
||||
self._service = service
|
||||
self._discovery_url = self._environmentDiscoveryUrl() or discovery_url
|
||||
|
||||
self._coordinator_name = coordinator_name
|
||||
if coordinator_name and not coordinator_address and (should_resolve or
|
||||
in_gke):
|
||||
if (coordinator_name and not coordinator_address and
|
||||
(should_resolve or in_gke)):
|
||||
self._start_local_server()
|
||||
else:
|
||||
self._coordinator_address = coordinator_address
|
||||
@ -270,7 +291,8 @@ class TPUClusterResolver(ClusterResolver):
|
||||
# Case 1.
|
||||
full_name = 'projects/%s/locations/%s/nodes/%s' % (
|
||||
self._project, self._zone, compat.as_text(self._tpu))
|
||||
request = self._service.projects().locations().nodes().get(name=full_name)
|
||||
service = self._tpuService()
|
||||
request = service.projects().locations().nodes().get(name=full_name)
|
||||
response = request.execute()
|
||||
|
||||
if 'state' in response and response['state'] != 'READY':
|
||||
|
@ -459,10 +459,10 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
|
||||
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
|
||||
|
||||
def testDiscoveryUrl(self):
|
||||
def testEnvironmentDiscoveryUrl(self):
|
||||
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
|
||||
self.assertEqual('https://{api}.internal/{apiVersion}',
|
||||
TPUClusterResolver._discoveryUrl())
|
||||
TPUClusterResolver._environmentDiscoveryUrl())
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user