Further clean up TPUClusterResolver

PiperOrigin-RevId: 265969707
This commit is contained in:
Bruce Fontaine 2019-08-28 12:39:08 -07:00 committed by TensorFlower Gardener
parent 05cf54667e
commit 3af471cd27
2 changed files with 16 additions and 73 deletions

View File

@ -124,20 +124,17 @@ class TPUClusterResolver(ClusterResolver):
resp = urlopen(req)
return compat.as_bytes(resp.read())
def _is_google_environment(self):
def _is_local_tpu(self):
return (
self._tpu == compat.as_bytes('') or
self._tpu == compat.as_bytes('local') or
self._tpu.startswith(compat.as_bytes('localhost:')) or
self._tpu.startswith(compat.as_bytes('/bns')) or
self._tpu.startswith(compat.as_bytes('uptc://')))
self._tpu == compat.as_bytes('local'))
def _should_resolve(self):
if isinstance(self._should_resolve_override, bool):
return self._should_resolve_override
else:
return not (self._tpu.startswith(compat.as_bytes('grpc://')) or
self._is_google_environment())
self._is_local_tpu())
@staticmethod
def _get_device_dict_and_cores(devices):
@ -207,11 +204,11 @@ class TPUClusterResolver(ClusterResolver):
Args:
tpu: A string corresponding to the TPU to use. If the string is an empty
string, the string 'local', or a string that begins with 'grpc://' or
'/bns', then it is assumed to not correspond with a Cloud TPU and will
instead be passed as the session master and no ClusterSpec propagation
will be done. In the future, this may also support a list of strings
when multiple Cloud TPUs are used.
string, the string 'local', or a string that begins with 'grpc://',
then it is assumed not to correspond to a Cloud TPU. It is instead
passed as the session master, and no ClusterSpec propagation is
done. In the future, this may also support a list of strings when
multiple Cloud TPUs are used.
zone: Zone where the TPUs are located. If omitted or empty, we will assume
that the zone of the TPU is the same as the zone of the GCE VM, which we
will try to discover from the GCE metadata service.
@ -273,29 +270,8 @@ class TPUClusterResolver(ClusterResolver):
self.task_type = job_name
self.task_id = 0
# TODO(bfontain): Remove Google specific code from this class.
if self._is_google_environment():
self._environment = 'google'
if self._is_local_tpu():
self.rpc_layer = None
# TODO(rsopher): remove this logic when possible
if self._tpu and self._tpu.startswith(compat.as_bytes('/bns')):
bns_and_port = self._tpu.rsplit(compat.as_bytes(':'), 1)
if len(bns_and_port) == 2:
try:
int(bns_and_port[1])
except ValueError:
# Leave named ports.
pass
else:
# Strip numerical ports.
self._tpu = bns_and_port[0]
# Remove '.brain' suffix.
# TODO(b/139700237): Support bns address with named port.
if ops.executing_eagerly_outside_functions() and self._tpu.endswith(
compat.as_bytes('.brain')):
self._tpu = self._tpu[:-6]
else:
self._environment = ''
self.rpc_layer = 'grpc'

View File

@ -27,7 +27,6 @@ from tensorflow.python import eager
from tensorflow.python.client import session
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@ -436,15 +435,9 @@ class TPUClusterResolverTest(test.TestCase):
def testShouldResolveLocal(self):
self.verifyShouldResolve('local', False)
def testShouldResolveLocalhost(self):
self.verifyShouldResolve('localhost:12345', False)
def testShouldResolveGrpc(self):
self.verifyShouldResolve('grpc://10.1.2.3:8470', False)
def testShouldResolveBns(self):
self.verifyShouldResolve('/bns/foo/bar', False)
def testShouldResolveName(self):
self.verifyShouldResolve('mytpu', True)
@ -455,20 +448,13 @@ class TPUClusterResolverTest(test.TestCase):
self.verifyShouldResolve('grpctpu', True)
def testNoCallComputeMetadata(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/foo/bar')
self.assertEqual('/bns/foo/bar', cluster_resolver.master())
if ops.executing_eagerly_outside_functions():
self.assertEqual(
server_lib.ClusterSpec({
'worker': ['/bns/foo/bar']
}).as_dict(),
cluster_resolver.cluster_spec().as_dict())
else:
self.assertEqual(None, cluster_resolver.cluster_spec())
def testLocalhostMaster(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='localhost:12345')
self.assertEqual('localhost:12345', cluster_resolver.master())
cluster_resolver = resolver.TPUClusterResolver(tpu='grpc://10.1.2.3:8470')
self.assertEqual('grpc://10.1.2.3:8470', cluster_resolver.master())
self.assertEqual(
server_lib.ClusterSpec({
'worker': ['10.1.2.3:8470']
}).as_dict(),
cluster_resolver.cluster_spec().as_dict())
def testGkeEnvironmentForDonut(self):
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
@ -535,25 +521,6 @@ class TPUClusterResolverTest(test.TestCase):
'https://{api}.internal/{apiVersion}',
(resolver.TPUClusterResolver._environment_discovery_url()))
def testEnvironmentAndRpcDetectionForGoogle(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef')
self.assertEqual(cluster_resolver.environment, 'google')
self.assertEqual(cluster_resolver.rpc_layer, None)
self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef'))
def testEnvironmentAndRpcDetectionForGoogleNumericalPort(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:1234')
self.assertEqual(cluster_resolver.environment, 'google')
self.assertEqual(cluster_resolver.rpc_layer, None)
self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef'))
def testEnvironmentAndRpcDetectionForGoogleNamedPort(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:port')
self.assertEqual(cluster_resolver.environment, 'google')
self.assertEqual(cluster_resolver.rpc_layer, None)
self.assertEqual(cluster_resolver._tpu,
compat.as_bytes('/bns/ab/cd/ef:port'))
def testEnvironmentAndRpcDetectionForGrpcString(self):
cluster_resolver = resolver.TPUClusterResolver(
tpu='grpc://10.1.2.3:8470')