Further clean up TPUClusterResolver

PiperOrigin-RevId: 265969707
This commit is contained in:
Bruce Fontaine 2019-08-28 12:39:08 -07:00 committed by TensorFlower Gardener
parent 05cf54667e
commit 3af471cd27
2 changed files with 16 additions and 73 deletions

View File

@ -124,20 +124,17 @@ class TPUClusterResolver(ClusterResolver):
resp = urlopen(req)
return compat.as_bytes(resp.read())
def _is_google_environment(self):
def _is_local_tpu(self):
return (
self._tpu == compat.as_bytes('') or
self._tpu == compat.as_bytes('local') or
self._tpu.startswith(compat.as_bytes('localhost:')) or
self._tpu.startswith(compat.as_bytes('/bns')) or
self._tpu.startswith(compat.as_bytes('uptc://')))
self._tpu == compat.as_bytes('local'))
def _should_resolve(self):
if isinstance(self._should_resolve_override, bool):
return self._should_resolve_override
else:
return not (self._tpu.startswith(compat.as_bytes('grpc://')) or
self._is_google_environment())
self._is_local_tpu())
@staticmethod
def _get_device_dict_and_cores(devices):
@ -207,8 +204,8 @@ class TPUClusterResolver(ClusterResolver):
Args:
tpu: A string corresponding to the TPU to use. If the string is an empty
string, the string 'local', or a string that begins with 'grpc://' or
'/bns', then it is assumed to not correspond with a Cloud TPU and will
string, the string 'local', or a string that begins with 'grpc://',
then it is assumed to not correspond with a Cloud TPU and will
instead be passed as the session master and no ClusterSpec propagation
will be done. In the future, this may also support a list of strings
when multiple Cloud TPUs are used.
@ -273,29 +270,8 @@ class TPUClusterResolver(ClusterResolver):
self.task_type = job_name
self.task_id = 0
# TODO(bfontain): Remove Google specific code from this class.
if self._is_google_environment():
self._environment = 'google'
if self._is_local_tpu():
self.rpc_layer = None
# TODO(rsopher): remove this logic when possible
if self._tpu and self._tpu.startswith(compat.as_bytes('/bns')):
bns_and_port = self._tpu.rsplit(compat.as_bytes(':'), 1)
if len(bns_and_port) == 2:
try:
int(bns_and_port[1])
except ValueError:
# Leave named ports.
pass
else:
# Strip numerical ports.
self._tpu = bns_and_port[0]
# Remove '.brain' suffix.
# TODO(b/139700237): Support bns address with named port.
if ops.executing_eagerly_outside_functions() and self._tpu.endswith(
compat.as_bytes('.brain')):
self._tpu = self._tpu[:-6]
else:
self._environment = ''
self.rpc_layer = 'grpc'

View File

@ -27,7 +27,6 @@ from tensorflow.python import eager
from tensorflow.python.client import session
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@ -436,15 +435,9 @@ class TPUClusterResolverTest(test.TestCase):
def testShouldResolveLocal(self):
self.verifyShouldResolve('local', False)
def testShouldResolveLocalhost(self):
self.verifyShouldResolve('localhost:12345', False)
def testShouldResolveGrpc(self):
self.verifyShouldResolve('grpc://10.1.2.3:8470', False)
def testShouldResolveBns(self):
self.verifyShouldResolve('/bns/foo/bar', False)
def testShouldResolveName(self):
self.verifyShouldResolve('mytpu', True)
@ -455,20 +448,13 @@ class TPUClusterResolverTest(test.TestCase):
self.verifyShouldResolve('grpctpu', True)
def testNoCallComputeMetadata(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/foo/bar')
self.assertEqual('/bns/foo/bar', cluster_resolver.master())
if ops.executing_eagerly_outside_functions():
cluster_resolver = resolver.TPUClusterResolver(tpu='grpc://10.1.2.3:8470')
self.assertEqual('grpc://10.1.2.3:8470', cluster_resolver.master())
self.assertEqual(
server_lib.ClusterSpec({
'worker': ['/bns/foo/bar']
'worker': ['10.1.2.3:8470']
}).as_dict(),
cluster_resolver.cluster_spec().as_dict())
else:
self.assertEqual(None, cluster_resolver.cluster_spec())
def testLocalhostMaster(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='localhost:12345')
self.assertEqual('localhost:12345', cluster_resolver.master())
def testGkeEnvironmentForDonut(self):
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
@ -535,25 +521,6 @@ class TPUClusterResolverTest(test.TestCase):
'https://{api}.internal/{apiVersion}',
(resolver.TPUClusterResolver._environment_discovery_url()))
def testEnvironmentAndRpcDetectionForGoogle(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef')
self.assertEqual(cluster_resolver.environment, 'google')
self.assertEqual(cluster_resolver.rpc_layer, None)
self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef'))
def testEnvironmentAndRpcDetectionForGoogleNumericalPort(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:1234')
self.assertEqual(cluster_resolver.environment, 'google')
self.assertEqual(cluster_resolver.rpc_layer, None)
self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef'))
def testEnvironmentAndRpcDetectionForGoogleNamedPort(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:port')
self.assertEqual(cluster_resolver.environment, 'google')
self.assertEqual(cluster_resolver.rpc_layer, None)
self.assertEqual(cluster_resolver._tpu,
compat.as_bytes('/bns/ab/cd/ef:port'))
def testEnvironmentAndRpcDetectionForGrpcString(self):
cluster_resolver = resolver.TPUClusterResolver(
tpu='grpc://10.1.2.3:8470')