diff --git a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py index eb701eff327..fb3911aa928 100644 --- a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py +++ b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py @@ -124,20 +124,17 @@ class TPUClusterResolver(ClusterResolver): resp = urlopen(req) return compat.as_bytes(resp.read()) - def _is_google_environment(self): + def _is_local_tpu(self): return ( self._tpu == compat.as_bytes('') or - self._tpu == compat.as_bytes('local') or - self._tpu.startswith(compat.as_bytes('localhost:')) or - self._tpu.startswith(compat.as_bytes('/bns')) or - self._tpu.startswith(compat.as_bytes('uptc://'))) + self._tpu == compat.as_bytes('local')) def _should_resolve(self): if isinstance(self._should_resolve_override, bool): return self._should_resolve_override else: return not (self._tpu.startswith(compat.as_bytes('grpc://')) or - self._is_google_environment()) + self._is_local_tpu()) @staticmethod def _get_device_dict_and_cores(devices): @@ -207,11 +204,11 @@ class TPUClusterResolver(ClusterResolver): Args: tpu: A string corresponding to the TPU to use. If the string is an empty - string, the string 'local', or a string that begins with 'grpc://' or - '/bns', then it is assumed to not correspond with a Cloud TPU and will - instead be passed as the session master and no ClusterSpec propagation - will be done. In the future, this may also support a list of strings - when multiple Cloud TPUs are used. + string, the string 'local', or a string that begins with 'grpc://', + then it is assumed to not correspond with a Cloud TPU and will + instead be passed as the session master and no ClusterSpec propagation + will be done. In the future, this may also support a list of strings + when multiple Cloud TPUs are used. zone: Zone where the TPUs are located. If omitted or empty, we will assume that the zone of the TPU is the same as the zone of the GCE VM, which we will try to discover from the GCE metadata service. @@ -273,29 +270,8 @@ class TPUClusterResolver(ClusterResolver): self.task_type = job_name self.task_id = 0 - # TODO(bfontain): Remove Google specific code from this class. - if self._is_google_environment(): - self._environment = 'google' + if self._is_local_tpu(): self.rpc_layer = None - - # TODO(rsopher): remove this logic when possible - if self._tpu and self._tpu.startswith(compat.as_bytes('/bns')): - bns_and_port = self._tpu.rsplit(compat.as_bytes(':'), 1) - if len(bns_and_port) == 2: - try: - int(bns_and_port[1]) - except ValueError: - # Leave named ports. - pass - else: - # Strip numerical ports. - self._tpu = bns_and_port[0] - - # Remove '.brain' suffix. - # TODO(b/139700237): Support bns address with named port. - if ops.executing_eagerly_outside_functions() and self._tpu.endswith( - compat.as_bytes('.brain')): - self._tpu = self._tpu[:-6] else: self._environment = '' self.rpc_layer = 'grpc' diff --git a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py index 9cfd63ae8cf..83ded5c18b6 100644 --- a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py +++ b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver_test.py @@ -27,7 +27,6 @@ from tensorflow.python import eager from tensorflow.python.client import session from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -436,15 +435,9 @@ class TPUClusterResolverTest(test.TestCase): def testShouldResolveLocal(self): self.verifyShouldResolve('local', False) - def testShouldResolveLocalhost(self): - self.verifyShouldResolve('localhost:12345', False) - def testShouldResolveGrpc(self): self.verifyShouldResolve('grpc://10.1.2.3:8470', False) - def testShouldResolveBns(self): - self.verifyShouldResolve('/bns/foo/bar', False) - def testShouldResolveName(self): self.verifyShouldResolve('mytpu', True) @@ -455,20 +448,13 @@ class TPUClusterResolverTest(test.TestCase): self.verifyShouldResolve('grpctpu', True) def testNoCallComputeMetadata(self): - cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/foo/bar') - self.assertEqual('/bns/foo/bar', cluster_resolver.master()) - if ops.executing_eagerly_outside_functions(): - self.assertEqual( - server_lib.ClusterSpec({ - 'worker': ['/bns/foo/bar'] - }).as_dict(), - cluster_resolver.cluster_spec().as_dict()) - else: - self.assertEqual(None, cluster_resolver.cluster_spec()) - - def testLocalhostMaster(self): - cluster_resolver = resolver.TPUClusterResolver(tpu='localhost:12345') - self.assertEqual('localhost:12345', cluster_resolver.master()) + cluster_resolver = resolver.TPUClusterResolver(tpu='grpc://10.1.2.3:8470') + self.assertEqual('grpc://10.1.2.3:8470', cluster_resolver.master()) + self.assertEqual( + server_lib.ClusterSpec({ + 'worker': ['10.1.2.3:8470'] + }).as_dict(), + cluster_resolver.cluster_spec().as_dict()) def testGkeEnvironmentForDonut(self): os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' @@ -535,25 +521,6 @@ class TPUClusterResolverTest(test.TestCase): 'https://{api}.internal/{apiVersion}', (resolver.TPUClusterResolver._environment_discovery_url())) - def testEnvironmentAndRpcDetectionForGoogle(self): - cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef') - self.assertEqual(cluster_resolver.environment, 'google') - self.assertEqual(cluster_resolver.rpc_layer, None) - self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef')) - - def testEnvironmentAndRpcDetectionForGoogleNumericalPort(self): - cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:1234') - self.assertEqual(cluster_resolver.environment, 'google') - self.assertEqual(cluster_resolver.rpc_layer, None) - self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef')) - - def testEnvironmentAndRpcDetectionForGoogleNamedPort(self): - cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:port') - self.assertEqual(cluster_resolver.environment, 'google') - self.assertEqual(cluster_resolver.rpc_layer, None) - self.assertEqual(cluster_resolver._tpu, - compat.as_bytes('/bns/ab/cd/ef:port')) - def testEnvironmentAndRpcDetectionForGrpcString(self): cluster_resolver = resolver.TPUClusterResolver( tpu='grpc://10.1.2.3:8470')