diff --git a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
index 943b736fde4..e42420ec644 100644
--- a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
+++ b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
@@ -52,16 +52,59 @@ class TPUClusterResolver(cluster_resolver.ClusterResolver):
   """Cluster Resolver for Google Cloud TPUs.
 
   This is an implementation of cluster resolvers for the Google Cloud TPU
-  service. As Cloud TPUs are in alpha, you will need to specify a API definition
-  file for this to consume, in addition to a list of Cloud TPUs in your Google
-  Cloud Platform project.
+  service.
 
   TPUClusterResolver supports the following distinct environments:
   Google Compute Engine
   Google Kubernetes Engine
   Google internal
+
+  It can be passed into `tf.distribute.TPUStrategy` to support TF2 training on
+  Cloud TPUs.
   """
 
+  @staticmethod
+  def connect(tpu=None,
+              zone=None,
+              project=None):
+    """Initializes TPU and returns a TPUClusterResolver.
+
+    This API will connect to a remote TPU cluster and initialize the TPU
+    hardware. Example usage:
+
+    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
+    ...     tpu='')
+
+    It can be viewed as a convenient wrapper of the following code:
+
+    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+    >>> tf.config.experimental_connect_to_cluster(resolver)
+    >>> tf.tpu.experimental.initialize_tpu_system(resolver)
+
+    Args:
+      tpu: A string corresponding to the TPU to use. It can be the TPU name or
+        TPU worker gRPC address. If not set, it will try to automatically
+        resolve the TPU address on Cloud TPUs.
+      zone: Zone where the TPUs are located. If omitted or empty, we will
+        assume that the zone of the TPU is the same as the zone of the GCE VM,
+        which we will try to discover from the GCE metadata service.
+      project: Name of the GCP project containing Cloud TPUs. If omitted or
+        empty, we will try to discover the project name of the GCE VM from the
+        GCE metadata service.
+
+    Returns:
+      An instance of TPUClusterResolver.
+
+    Raises:
+      NotFoundError: If no TPU devices are found in eager mode.
+    """
+    resolver = TPUClusterResolver(tpu, zone, project)
+    from tensorflow.python.eager import remote  # pylint: disable=g-import-not-at-top
+    remote.connect_to_cluster(resolver)
+    from tensorflow.python.tpu import tpu_strategy_util  # pylint: disable=g-import-not-at-top
+    tpu_strategy_util.initialize_tpu_system(resolver)
+    return resolver
+
   @staticmethod
   def _get_device_dict_and_cores(devices):
     """Returns a dict of hosts to cores and total cores given devices names.
@@ -110,12 +153,9 @@ class TPUClusterResolver(cluster_resolver.ClusterResolver):
       for the IP addresses and ports of each Cloud TPU listed.
 
     Args:
-      tpu: A string corresponding to the TPU to use. If the string is an empty
-        string, the string 'local', or a string that begins with 'grpc://', then
-        it is assumed to not correspond with a Cloud TPU and will instead be
-        passed as the session master and no ClusterSpec propagation will be
-        done. In the future, this may also support a list of strings when
-        multiple Cloud TPUs are used.
+      tpu: A string corresponding to the TPU to use. It can be the TPU name or
+        TPU worker gRPC address. If not set, it will try to automatically
+        resolve the TPU address on Cloud TPUs.
       zone: Zone where the TPUs are located. If omitted or empty, we will
         assume that the zone of the TPU is the same as the zone of the GCE VM,
         which we will try to discover from the GCE metadata service.
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 7e8f5b97e7e..e2e75260593 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -338,7 +338,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     if context.executing_eagerly():
       # In async remote eager, we want to sync the exectors before exiting the
       # program.
-      atexit.register(context.async_wait)
+      def async_wait():
+        if context.context()._context_handle is not None:  # pylint: disable=protected-access
+          context.async_wait()
+      atexit.register(async_wait)
 
     # TODO(bfontain): Remove once a proper dataset API exists for prefetching
     # a dataset to multiple devices exists.
diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py
index 543c91167cd..c315d7c5e1b 100644
--- a/tensorflow/python/tpu/tpu_strategy_util.py
+++ b/tensorflow/python/tpu/tpu_strategy_util.py
@@ -24,6 +24,7 @@ from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import T
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
 from tensorflow.python.framework import device
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import topology
@@ -47,8 +48,8 @@ def initialize_tpu_system(cluster_resolver=None):
     The tf.tpu.Topology object for the topology of the TPU cluster.
 
   Raises:
-    RuntimeError: If no TPU devices found for eager execution or if run in a
-    tf.function.
+    RuntimeError: If running inside a tf.function.
+    NotFoundError: If no TPU devices are found in eager mode.
   """
   job = None
   if cluster_resolver is None:
@@ -93,8 +94,15 @@ def initialize_tpu_system(cluster_resolver=None):
   # The TPU_SYSTEM device must match the device used in tpu.initialize_system
   # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
   # devices available.
-  with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
-    output = _tpu_init_fn()
+  try:
+    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
+      output = _tpu_init_fn()
+    context.async_wait()
+  except errors.InvalidArgumentError as e:
+    raise errors.NotFoundError(
+        None, None,
+        "TPUs not found in the cluster. Failed in initialization: "
+        + str(e))
 
   # Clear out the eager context caches since the memory is invalid now.
   logging.info("Clearing out eager caches")
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt
index 658212aca5e..8eee489df93 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt
@@ -15,6 +15,10 @@ tf_class {
     name: "cluster_spec"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "connect"
+    argspec: "args=[\'tpu\', \'zone\', \'project\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
   member_method {
     name: "get_job_name"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt
index 658212aca5e..8eee489df93 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt
@@ -15,6 +15,10 @@ tf_class {
     name: "cluster_spec"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "connect"
+    argspec: "args=[\'tpu\', \'zone\', \'project\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
   member_method {
     name: "get_job_name"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index f347cab91bb..d2f5469a55c 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -11,7 +11,7 @@ package(
 
 exports_files(["LICENSE"])
 
-tpu_module = "tpu.,distribute.tpu_strategy"
+tpu_module = "tpu.,distribute.tpu_strategy,distribute.cluster_resolver.tpu,distribute.cluster_resolver.tpu_oss"
 
 py_library(
     name = "tf_doctest_lib",