From 89c0134fd811a12d8100da385cb9a218247f0933 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sun, 11 Nov 2018 20:11:46 -0800 Subject: [PATCH] Fixes an issue in TPUClusterResolver where an underlying HTTP connection times out, resulting in user visible errors PiperOrigin-RevId: 221031372 --- .../python/training/tpu_cluster_resolver.py | 84 ++++++++++++------- .../training/tpu_cluster_resolver_test.py | 4 +- 2 files changed, 55 insertions(+), 33 deletions(-) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index c4ac9d07001..0157d2a0c86 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -50,6 +50,34 @@ class TPUClusterResolver(ClusterResolver): Cloud Platform project. """ + def _tpuService(self): + """Creates a new Cloud TPU API object. + + This works around an issue where the underlying HTTP connection sometimes + times out when the script has been running for too long. Other methods in + this object call this method to get a new API object whenever they need + to communicate with the Cloud API. + + Returns: + A Google Cloud TPU API object. 
+ """ + if self._service: + return self._service + + credentials = self.credentials + if credentials is None or credentials == 'default': + credentials = GoogleCredentials.get_application_default() + + if self._discovery_url: + return discovery.build( + 'tpu', 'v1alpha1', + credentials=credentials, + discoveryServiceUrl=self._discovery_url) + else: + return discovery.build( + 'tpu', 'v1alpha1', + credentials=credentials) + def _requestComputeMetadata(self, path): req = Request('http://metadata/computeMetadata/v1/%s' % path, headers={'Metadata-Flavor': 'Google'}) @@ -81,7 +109,7 @@ class TPUClusterResolver(ClusterResolver): return None @staticmethod - def _discoveryUrl(): + def _environmentDiscoveryUrl(): return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) def __init__(self, @@ -154,49 +182,42 @@ class TPUClusterResolver(ClusterResolver): self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name - self._credentials = credentials + # Whether we should actually attempt to contact Cloud APIs should_resolve = self._shouldResolve() + # We error out if we are in a non-Cloud environment which cannot talk to the + # Cloud APIs using the standard class and a special object is not passed in. + self._service = service + if (self._service is None and should_resolve and + not _GOOGLE_API_CLIENT_INSTALLED): + raise ImportError('googleapiclient and oauth2client must be installed ' + 'before using the TPU cluster resolver. Execute: ' + '`pip install --upgrade google-api-python-client` ' + 'and `pip install --upgrade oauth2client` to ' + 'install with pip.') + + # We save user-passed credentials, unless the user didn't pass in anything. + self._credentials = credentials + if (credentials == 'default' and should_resolve and + _GOOGLE_API_CLIENT_INSTALLED): + self._credentials = None + + # Automatically detect project and zone if unspecified. 
if not project and should_resolve: project = compat.as_str( self._requestComputeMetadata('project/project-id')) - if not zone and should_resolve: zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) zone = zone_path.split('/')[-1] - self._project = project self._zone = zone - if credentials == 'default' and should_resolve: - if _GOOGLE_API_CLIENT_INSTALLED: - self._credentials = GoogleCredentials.get_application_default() - - if service is None and should_resolve: - if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient and oauth2client must be installed ' - 'before using the TPU cluster resolver. Execute: ' - '`pip install --upgrade google-api-python-client` ' - 'and `pip install --upgrade oauth2client` to ' - 'install with pip.') - - final_discovery_url = self._discoveryUrl() or discovery_url - if final_discovery_url: - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials, - discoveryServiceUrl=final_discovery_url) - else: - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) - else: - self._service = service + self._discovery_url = self._environmentDiscoveryUrl() or discovery_url self._coordinator_name = coordinator_name - if coordinator_name and not coordinator_address and (should_resolve or - in_gke): + if (coordinator_name and not coordinator_address and + (should_resolve or in_gke)): self._start_local_server() else: self._coordinator_address = coordinator_address @@ -270,7 +291,8 @@ class TPUClusterResolver(ClusterResolver): # Case 1. 
full_name = 'projects/%s/locations/%s/nodes/%s' % ( self._project, self._zone, compat.as_text(self._tpu)) - request = self._service.projects().locations().nodes().get(name=full_name) + service = self._tpuService() + request = service.projects().locations().nodes().get(name=full_name) response = request.execute() if 'state' in response and response['state'] != 'READY': diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index ad4f6432630..478c82967ba 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -459,10 +459,10 @@ class TPUClusterResolverTest(test.TestCase): del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] - def testDiscoveryUrl(self): + def testEnvironmentDiscoveryUrl(self): os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' self.assertEqual('https://{api}.internal/{apiVersion}', - TPUClusterResolver._discoveryUrl()) + TPUClusterResolver._environmentDiscoveryUrl()) if __name__ == '__main__': test.main()