Fixes an issue in TPUClusterResolver where an underlying HTTP connection times out, resulting in user-visible errors

PiperOrigin-RevId: 221031372
This commit is contained in:
Frank Chen 2018-11-11 20:11:46 -08:00 committed by TensorFlower Gardener
parent 19f97dc4bd
commit 89c0134fd8
2 changed files with 55 additions and 33 deletions

View File

@ -50,6 +50,34 @@ class TPUClusterResolver(ClusterResolver):
Cloud Platform project. Cloud Platform project.
""" """
def _tpuService(self):
  """Returns a Cloud TPU API object to use for one API interaction.

  This works around an issue where the underlying HTTP connection sometimes
  times out when the script has been running for too long. Other methods in
  this object call this method to get an API object whenever they need to
  communicate with the Cloud API.

  If a service object is already stored on this resolver (`self._service`),
  it is returned unchanged; otherwise a new API object is built on every
  call instead of being cached.

  Returns:
    A Google Cloud TPU API object.
  """
  if self._service:
    return self._service

  # The constructor stores the credentials on `self._credentials` (there is
  # no public `credentials` attribute set there). Both `None` and 'default'
  # mean "use the application default credentials".
  credentials = self._credentials
  if credentials is None or credentials == 'default':
    credentials = GoogleCredentials.get_application_default()

  # Only pass a discovery URL when one was configured, so the client
  # library's own default is used otherwise.
  if self._discovery_url:
    return discovery.build(
        'tpu', 'v1alpha1',
        credentials=credentials,
        discoveryServiceUrl=self._discovery_url)
  return discovery.build(
      'tpu', 'v1alpha1',
      credentials=credentials)
def _requestComputeMetadata(self, path): def _requestComputeMetadata(self, path):
req = Request('http://metadata/computeMetadata/v1/%s' % path, req = Request('http://metadata/computeMetadata/v1/%s' % path,
headers={'Metadata-Flavor': 'Google'}) headers={'Metadata-Flavor': 'Google'})
@ -81,7 +109,7 @@ class TPUClusterResolver(ClusterResolver):
return None return None
@staticmethod @staticmethod
def _discoveryUrl(): def _environmentDiscoveryUrl():
return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
def __init__(self, def __init__(self,
@ -154,49 +182,42 @@ class TPUClusterResolver(ClusterResolver):
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
self._job_name = job_name self._job_name = job_name
self._credentials = credentials
# Whether we should actually attempt to contact Cloud APIs
should_resolve = self._shouldResolve() should_resolve = self._shouldResolve()
# We error out if we are in a non-Cloud environment which cannot talk to the
# Cloud APIs using the standard class and a special object is not passed in.
self._service = service
if (self._service is None and should_resolve and
not _GOOGLE_API_CLIENT_INSTALLED):
raise ImportError('googleapiclient and oauth2client must be installed '
'before using the TPU cluster resolver. Execute: '
'`pip install --upgrade google-api-python-client` '
'and `pip install --upgrade oauth2client` to '
'install with pip.')
# We save user-passed credentials, unless the user didn't pass in anything.
self._credentials = credentials
if (credentials == 'default' and should_resolve and
_GOOGLE_API_CLIENT_INSTALLED):
self._credentials = None
# Automatically detect project and zone if unspecified.
if not project and should_resolve: if not project and should_resolve:
project = compat.as_str( project = compat.as_str(
self._requestComputeMetadata('project/project-id')) self._requestComputeMetadata('project/project-id'))
if not zone and should_resolve: if not zone and should_resolve:
zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
zone = zone_path.split('/')[-1] zone = zone_path.split('/')[-1]
self._project = project self._project = project
self._zone = zone self._zone = zone
if credentials == 'default' and should_resolve: self._discovery_url = self._environmentDiscoveryUrl() or discovery_url
if _GOOGLE_API_CLIENT_INSTALLED:
self._credentials = GoogleCredentials.get_application_default()
if service is None and should_resolve:
if not _GOOGLE_API_CLIENT_INSTALLED:
raise ImportError('googleapiclient and oauth2client must be installed '
'before using the TPU cluster resolver. Execute: '
'`pip install --upgrade google-api-python-client` '
'and `pip install --upgrade oauth2client` to '
'install with pip.')
final_discovery_url = self._discoveryUrl() or discovery_url
if final_discovery_url:
self._service = discovery.build(
'tpu', 'v1alpha1',
credentials=self._credentials,
discoveryServiceUrl=final_discovery_url)
else:
self._service = discovery.build(
'tpu', 'v1alpha1',
credentials=self._credentials)
else:
self._service = service
self._coordinator_name = coordinator_name self._coordinator_name = coordinator_name
if coordinator_name and not coordinator_address and (should_resolve or if (coordinator_name and not coordinator_address and
in_gke): (should_resolve or in_gke)):
self._start_local_server() self._start_local_server()
else: else:
self._coordinator_address = coordinator_address self._coordinator_address = coordinator_address
@ -270,7 +291,8 @@ class TPUClusterResolver(ClusterResolver):
# Case 1. # Case 1.
full_name = 'projects/%s/locations/%s/nodes/%s' % ( full_name = 'projects/%s/locations/%s/nodes/%s' % (
self._project, self._zone, compat.as_text(self._tpu)) self._project, self._zone, compat.as_text(self._tpu))
request = self._service.projects().locations().nodes().get(name=full_name) service = self._tpuService()
request = service.projects().locations().nodes().get(name=full_name)
response = request.execute() response = request.execute()
if 'state' in response and response['state'] != 'READY': if 'state' in response and response['state'] != 'READY':

View File

@ -459,10 +459,10 @@ class TPUClusterResolverTest(test.TestCase):
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
def testDiscoveryUrl(self): def testEnvironmentDiscoveryUrl(self):
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
self.assertEqual('https://{api}.internal/{apiVersion}', self.assertEqual('https://{api}.internal/{apiVersion}',
TPUClusterResolver._discoveryUrl()) TPUClusterResolver._environmentDiscoveryUrl())
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()