Further clean up TPUClusterResolver
PiperOrigin-RevId: 265969707
This commit is contained in:
parent 05cf54667e
commit 3af471cd27
@@ -124,20 +124,17 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
resp = urlopen(req)
|
resp = urlopen(req)
|
||||||
return compat.as_bytes(resp.read())
|
return compat.as_bytes(resp.read())
|
||||||
|
|
||||||
def _is_google_environment(self):
|
def _is_local_tpu(self):
|
||||||
return (
|
return (
|
||||||
self._tpu == compat.as_bytes('') or
|
self._tpu == compat.as_bytes('') or
|
||||||
self._tpu == compat.as_bytes('local') or
|
self._tpu == compat.as_bytes('local'))
|
||||||
self._tpu.startswith(compat.as_bytes('localhost:')) or
|
|
||||||
self._tpu.startswith(compat.as_bytes('/bns')) or
|
|
||||||
self._tpu.startswith(compat.as_bytes('uptc://')))
|
|
||||||
|
|
||||||
def _should_resolve(self):
|
def _should_resolve(self):
|
||||||
if isinstance(self._should_resolve_override, bool):
|
if isinstance(self._should_resolve_override, bool):
|
||||||
return self._should_resolve_override
|
return self._should_resolve_override
|
||||||
else:
|
else:
|
||||||
return not (self._tpu.startswith(compat.as_bytes('grpc://')) or
|
return not (self._tpu.startswith(compat.as_bytes('grpc://')) or
|
||||||
self._is_google_environment())
|
self._is_local_tpu())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_device_dict_and_cores(devices):
|
def _get_device_dict_and_cores(devices):
|
||||||
@@ -207,8 +204,8 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
tpu: A string corresponding to the TPU to use. If the string is an empty
|
tpu: A string corresponding to the TPU to use. If the string is an empty
|
||||||
string, the string 'local', or a string that begins with 'grpc://' or
|
string, the string 'local', or a string that begins with 'grpc://',
|
||||||
'/bns', then it is assumed to not correspond with a Cloud TPU and will
|
then it is assumed to not correspond with a Cloud TPU and will
|
||||||
instead be passed as the session master and no ClusterSpec propagation
|
instead be passed as the session master and no ClusterSpec propagation
|
||||||
will be done. In the future, this may also support a list of strings
|
will be done. In the future, this may also support a list of strings
|
||||||
when multiple Cloud TPUs are used.
|
when multiple Cloud TPUs are used.
|
||||||
@@ -273,29 +270,8 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
self.task_type = job_name
|
self.task_type = job_name
|
||||||
self.task_id = 0
|
self.task_id = 0
|
||||||
|
|
||||||
# TODO(bfontain): Remove Google specific code from this class.
|
if self._is_local_tpu():
|
||||||
if self._is_google_environment():
|
|
||||||
self._environment = 'google'
|
|
||||||
self.rpc_layer = None
|
self.rpc_layer = None
|
||||||
|
|
||||||
# TODO(rsopher): remove this logic when possible
|
|
||||||
if self._tpu and self._tpu.startswith(compat.as_bytes('/bns')):
|
|
||||||
bns_and_port = self._tpu.rsplit(compat.as_bytes(':'), 1)
|
|
||||||
if len(bns_and_port) == 2:
|
|
||||||
try:
|
|
||||||
int(bns_and_port[1])
|
|
||||||
except ValueError:
|
|
||||||
# Leave named ports.
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# Strip numerical ports.
|
|
||||||
self._tpu = bns_and_port[0]
|
|
||||||
|
|
||||||
# Remove '.brain' suffix.
|
|
||||||
# TODO(b/139700237): Support bns address with named port.
|
|
||||||
if ops.executing_eagerly_outside_functions() and self._tpu.endswith(
|
|
||||||
compat.as_bytes('.brain')):
|
|
||||||
self._tpu = self._tpu[:-6]
|
|
||||||
else:
|
else:
|
||||||
self._environment = ''
|
self._environment = ''
|
||||||
self.rpc_layer = 'grpc'
|
self.rpc_layer = 'grpc'
|
||||||
|
@@ -27,7 +27,6 @@ from tensorflow.python import eager
|
|||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
|
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
@@ -436,15 +435,9 @@ class TPUClusterResolverTest(test.TestCase):
|
|||||||
def testShouldResolveLocal(self):
|
def testShouldResolveLocal(self):
|
||||||
self.verifyShouldResolve('local', False)
|
self.verifyShouldResolve('local', False)
|
||||||
|
|
||||||
def testShouldResolveLocalhost(self):
|
|
||||||
self.verifyShouldResolve('localhost:12345', False)
|
|
||||||
|
|
||||||
def testShouldResolveGrpc(self):
|
def testShouldResolveGrpc(self):
|
||||||
self.verifyShouldResolve('grpc://10.1.2.3:8470', False)
|
self.verifyShouldResolve('grpc://10.1.2.3:8470', False)
|
||||||
|
|
||||||
def testShouldResolveBns(self):
|
|
||||||
self.verifyShouldResolve('/bns/foo/bar', False)
|
|
||||||
|
|
||||||
def testShouldResolveName(self):
|
def testShouldResolveName(self):
|
||||||
self.verifyShouldResolve('mytpu', True)
|
self.verifyShouldResolve('mytpu', True)
|
||||||
|
|
||||||
@@ -455,20 +448,13 @@ class TPUClusterResolverTest(test.TestCase):
|
|||||||
self.verifyShouldResolve('grpctpu', True)
|
self.verifyShouldResolve('grpctpu', True)
|
||||||
|
|
||||||
def testNoCallComputeMetadata(self):
|
def testNoCallComputeMetadata(self):
|
||||||
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/foo/bar')
|
cluster_resolver = resolver.TPUClusterResolver(tpu='grpc://10.1.2.3:8470')
|
||||||
self.assertEqual('/bns/foo/bar', cluster_resolver.master())
|
self.assertEqual('grpc://10.1.2.3:8470', cluster_resolver.master())
|
||||||
if ops.executing_eagerly_outside_functions():
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
server_lib.ClusterSpec({
|
server_lib.ClusterSpec({
|
||||||
'worker': ['/bns/foo/bar']
|
'worker': ['10.1.2.3:8470']
|
||||||
}).as_dict(),
|
}).as_dict(),
|
||||||
cluster_resolver.cluster_spec().as_dict())
|
cluster_resolver.cluster_spec().as_dict())
|
||||||
else:
|
|
||||||
self.assertEqual(None, cluster_resolver.cluster_spec())
|
|
||||||
|
|
||||||
def testLocalhostMaster(self):
|
|
||||||
cluster_resolver = resolver.TPUClusterResolver(tpu='localhost:12345')
|
|
||||||
self.assertEqual('localhost:12345', cluster_resolver.master())
|
|
||||||
|
|
||||||
def testGkeEnvironmentForDonut(self):
|
def testGkeEnvironmentForDonut(self):
|
||||||
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
|
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
|
||||||
@@ -535,25 +521,6 @@ class TPUClusterResolverTest(test.TestCase):
|
|||||||
'https://{api}.internal/{apiVersion}',
|
'https://{api}.internal/{apiVersion}',
|
||||||
(resolver.TPUClusterResolver._environment_discovery_url()))
|
(resolver.TPUClusterResolver._environment_discovery_url()))
|
||||||
|
|
||||||
def testEnvironmentAndRpcDetectionForGoogle(self):
|
|
||||||
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef')
|
|
||||||
self.assertEqual(cluster_resolver.environment, 'google')
|
|
||||||
self.assertEqual(cluster_resolver.rpc_layer, None)
|
|
||||||
self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef'))
|
|
||||||
|
|
||||||
def testEnvironmentAndRpcDetectionForGoogleNumericalPort(self):
|
|
||||||
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:1234')
|
|
||||||
self.assertEqual(cluster_resolver.environment, 'google')
|
|
||||||
self.assertEqual(cluster_resolver.rpc_layer, None)
|
|
||||||
self.assertEqual(cluster_resolver._tpu, compat.as_bytes('/bns/ab/cd/ef'))
|
|
||||||
|
|
||||||
def testEnvironmentAndRpcDetectionForGoogleNamedPort(self):
|
|
||||||
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:port')
|
|
||||||
self.assertEqual(cluster_resolver.environment, 'google')
|
|
||||||
self.assertEqual(cluster_resolver.rpc_layer, None)
|
|
||||||
self.assertEqual(cluster_resolver._tpu,
|
|
||||||
compat.as_bytes('/bns/ab/cd/ef:port'))
|
|
||||||
|
|
||||||
def testEnvironmentAndRpcDetectionForGrpcString(self):
|
def testEnvironmentAndRpcDetectionForGrpcString(self):
|
||||||
cluster_resolver = resolver.TPUClusterResolver(
|
cluster_resolver = resolver.TPUClusterResolver(
|
||||||
tpu='grpc://10.1.2.3:8470')
|
tpu='grpc://10.1.2.3:8470')
|
||||||
|
Loading…
Reference in New Issue
Block a user