Further clean up TPUClusterResolver
PiperOrigin-RevId: 265969707
This commit is contained in:
parent
05cf54667e
commit
3af471cd27
@ -124,20 +124,17 @@ class TPUClusterResolver(ClusterResolver):
|
||||
resp = urlopen(req)
|
||||
return compat.as_bytes(resp.read())
|
||||
|
||||
def _is_google_environment(self):
|
||||
def _is_local_tpu(self):
|
||||
return (
|
||||
self._tpu == compat.as_bytes('') or
|
||||
self._tpu == compat.as_bytes('local') or
|
||||
self._tpu.startswith(compat.as_bytes('localhost:')) or
|
||||
self._tpu.startswith(compat.as_bytes('/bns')) or
|
||||
self._tpu.startswith(compat.as_bytes('uptc://')))
|
||||
self._tpu == compat.as_bytes('local'))
|
||||
|
||||
def _should_resolve(self):
|
||||
if isinstance(self._should_resolve_override, bool):
|
||||
return self._should_resolve_override
|
||||
else:
|
||||
return not (self._tpu.startswith(compat.as_bytes('grpc://')) or
|
||||
self._is_google_environment())
|
||||
self._is_local_tpu())
|
||||
|
||||
@staticmethod
|
||||
def _get_device_dict_and_cores(devices):
|
||||
@ -207,8 +204,8 @@ class TPUClusterResolver(ClusterResolver):
|
||||
|
||||
Args:
|
||||
tpu: A string corresponding to the TPU to use. If the string is an empty
|
||||
string, the string 'local', or a string that begins with 'grpc://' or
|
||||
'/bns', then it is assumed to not correspond with a Cloud TPU and will
|
||||
string, the string 'local', or a string that begins with 'grpc://',
|
||||
then it is assumed to not correspond with a Cloud TPU and will
|
||||
instead be passed as the session master and no ClusterSpec propagation
|
||||
will be done. In the future, this may also support a list of strings
|
||||
when multiple Cloud TPUs are used.
|
||||
@ -273,29 +270,8 @@ class TPUClusterResolver(ClusterResolver):
|
||||
self.task_type = job_name
|
||||
self.task_id = 0
|
||||
|
||||
# TODO(bfontain): Remove Google specific code from this class.
|
||||
if self._is_google_environment():
|
||||
self._environment = 'google'
|
||||
if self._is_local_tpu():
|
||||
self.rpc_layer = None
|
||||
|
||||
# TODO(rsopher): remove this logic when possible
|
||||
if self._tpu and self._tpu.startswith(compat.as_bytes('/bns')):
|
||||
bns_and_port = self._tpu.rsplit(compat.as_bytes(':'), 1)
|
||||
if len(bns_and_port) == 2:
|
||||
try:
|
||||
int(bns_and_port[1])
|
||||
except ValueError:
|
||||
# Leave named ports.
|
||||
pass
|
||||
else:
|
||||
# Strip numerical ports.
|
||||
self._tpu = bns_and_port[0]
|
||||
|
||||
# Remove '.brain' suffix.
|
||||
# TODO(b/139700237): Support bns address with named port.
|
||||
if ops.executing_eagerly_outside_functions() and self._tpu.endswith(
|
||||
compat.as_bytes('.brain')):
|
||||
self._tpu = self._tpu[:-6]
|
||||
else:
|
||||
self._environment = ''
|
||||
self.rpc_layer = 'grpc'
|
||||
|
@ -27,7 +27,6 @@ from tensorflow.python import eager
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
@ -436,15 +435,9 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
def testShouldResolveLocal(self):
|
||||
self.verifyShouldResolve('local', False)
|
||||
|
||||
def testShouldResolveLocalhost(self):
|
||||
self.verifyShouldResolve('localhost:12345', False)
|
||||
|
||||
def testShouldResolveGrpc(self):
|
||||
self.verifyShouldResolve('grpc://10.1.2.3:8470', False)
|
||||
|
||||
def testShouldResolveBns(self):
|
||||
self.verifyShouldResolve('/bns/foo/bar', False)
|
||||
|
||||
def testShouldResolveName(self):
|
||||
self.verifyShouldResolve('mytpu', True)
|
||||
|
||||
@ -455,20 +448,13 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
self.verifyShouldResolve('grpctpu', True)
|
||||
|
||||
def testNoCallComputeMetadata(self):
|
||||
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/foo/bar')
|
||||
self.assertEqual('/bns/foo/bar', cluster_resolver.master())
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
cluster_resolver = resolver.TPUClusterResolver(tpu='grpc://10.1.2.3:8470')
|
||||
self.assertEqual('grpc://10.1.2.3:8470', cluster_resolver.master())
|
||||
self.assertEqual(
|
||||
server_lib.ClusterSpec({
|
||||
'worker': ['/bns/foo/bar']
|
||||
'worker': ['10.1.2.3:8470']
|
||||
}).as_dict(),
|
||||
cluster_resolver.cluster_spec().as_dict())
|
||||
else:
|
||||
self.assertEqual(None, cluster_resolver.cluster_spec())
|
||||
|
||||
def testLocalhostMaster(self):
|
||||
cluster_resolver = resolver.TPUClusterResolver(tpu='localhost:12345')
|
||||
self.assertEqual('localhost:12345', cluster_resolver.master())
|
||||
|
||||
def testGkeEnvironmentForDonut(self):
|
||||
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
|
||||
@ -535,25 +521,6 @@ class TPUClusterResolverTest(test.TestCase):
|
||||
'https://{api}.internal/{apiVersion}',
|
||||
(resolver.TPUClusterResolver._environment_discovery_url()))
|
||||
|
||||
def testEnvironmentAndRpcDetectionForGoogle(self):
|
||||
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef')
|
||||
self.assertEqual(cluster_resolver.environment, 'google')
|
||||
self.assertEqual(cluster_resolver.rpc_layer, None)
|
||||
self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef'))
|
||||
|
||||
def testEnvironmentAndRpcDetectionForGoogleNumericalPort(self):
|
||||
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:1234')
|
||||
self.assertEqual(cluster_resolver.environment, 'google')
|
||||
self.assertEqual(cluster_resolver.rpc_layer, None)
|
||||
self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef'))
|
||||
|
||||
def testEnvironmentAndRpcDetectionForGoogleNamedPort(self):
|
||||
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:port')
|
||||
self.assertEqual(cluster_resolver.environment, 'google')
|
||||
self.assertEqual(cluster_resolver.rpc_layer, None)
|
||||
self.assertEqual(cluster_resolver._tpu,
|
||||
compat.as_bytes('/bns/ab/cd/ef:port'))
|
||||
|
||||
def testEnvironmentAndRpcDetectionForGrpcString(self):
|
||||
cluster_resolver = resolver.TPUClusterResolver(
|
||||
tpu='grpc://10.1.2.3:8470')
|
||||
|
Loading…
Reference in New Issue
Block a user