Add a `TPUClusterResolver.connect` API to simplify TPU initialization.

RELNOTES=Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
PiperOrigin-RevId: 317439811
Change-Id: I2f1a944f3c440356b21da27a72855c969f1c3b3b
Ruoxin Sang 2020-06-20 00:06:21 -07:00 committed by TensorFlower Gardener
parent 27cb9aa834
commit ae78d9c315
6 changed files with 74 additions and 15 deletions
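The new `connect` API rolls cluster connection and TPU system initialization into a single call. A minimal end-to-end sketch of the intended usage, assuming a reachable Cloud TPU; passing tpu='' asks the resolver to auto-discover the TPU address, and the Keras model is purely illustrative:

import tensorflow as tf

# One call connects to the remote TPU cluster and initializes the TPU system.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu='')

# The returned resolver can be passed straight into TPUStrategy.
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
  # Variables created under the strategy scope are placed on the TPU workers.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer='sgd', loss='mse')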

@@ -52,16 +52,59 @@ class TPUClusterResolver(cluster_resolver.ClusterResolver):
"""Cluster Resolver for Google Cloud TPUs.
This is an implementation of cluster resolvers for the Google Cloud TPU
service. As Cloud TPUs are in alpha, you will need to specify a API definition
file for this to consume, in addition to a list of Cloud TPUs in your Google
Cloud Platform project.
service.
TPUClusterResolver supports the following distinct environments:
Google Compute Engine
Google Kubernetes Engine
Google internal
It can be passed into `tf.distribute.TPUStrategy` to support TF2 training on
Cloud TPUs.
"""
  @staticmethod
  def connect(tpu=None,
              zone=None,
              project=None):
    """Initializes TPU and returns a TPUClusterResolver.

    This API will connect to a remote TPU cluster and initialize the TPU
    hardware. Example usage:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
    ...     tpu='')

    It can be viewed as a convenient wrapper for the following code:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    >>> tf.config.experimental_connect_to_cluster(resolver)
    >>> tf.tpu.experimental.initialize_tpu_system(resolver)

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.

    Returns:
      An instance of `TPUClusterResolver`.

    Raises:
      NotFoundError: If no TPU devices are found in eager mode.
    """
    resolver = TPUClusterResolver(tpu, zone, project)
    from tensorflow.python.eager import remote  # pylint: disable=g-import-not-at-top
    remote.connect_to_cluster(resolver)
    from tensorflow.python.tpu import tpu_strategy_util  # pylint: disable=g-import-not-at-top
    tpu_strategy_util.initialize_tpu_system(resolver)
    return resolver

  @staticmethod
  def _get_device_dict_and_cores(devices):
    """Returns a dict of hosts to cores and total cores given device names.
@@ -110,12 +153,9 @@ class TPUClusterResolver(cluster_resolver.ClusterResolver):
    for the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu: A string corresponding to the TPU to use. If the string is an empty
        string, the string 'local', or a string that begins with 'grpc://', then
        it is assumed to not correspond with a Cloud TPU and will instead be
        passed as the session master and no ClusterSpec propagation will be
        done. In the future, this may also support a list of strings when
        multiple Cloud TPUs are used.
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.

@@ -338,7 +338,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
    if context.executing_eagerly():
      # In async remote eager, we want to sync the executors before exiting the
      # program.
      atexit.register(context.async_wait)
      def async_wait():
        if context.context()._context_handle is not None:  # pylint: disable=protected-access
          context.async_wait()
      atexit.register(async_wait)

    # TODO(bfontain): Remove once a proper dataset API for prefetching a
    # dataset to multiple devices exists.
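The guard above only flushes pending async work when an eager context handle was actually created, so processes that never initialized the runtime can exit without forcing a sync. The same defensive atexit pattern in isolation, as a sketch; the LazyRuntime class is an invented stand-in, not TensorFlow code:

import atexit

class LazyRuntime:
  """Invented stand-in for a runtime whose handle is created lazily."""

  def __init__(self):
    self.handle = None  # Only set once real work happens.

  def start(self):
    self.handle = object()

  def sync(self):
    print('flushing pending async work')

runtime = LazyRuntime()

def _sync_at_exit():
  # Skip the sync if the runtime was never started, mirroring the
  # _context_handle check in the change above.
  if runtime.handle is not None:
    runtime.sync()

atexit.register(_sync_at_exit)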

@@ -24,6 +24,7 @@ from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import T
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import topology
@@ -47,8 +48,8 @@ def initialize_tpu_system(cluster_resolver=None):
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices are found for eager execution or if run in a
        tf.function.
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices are found in eager mode.
  """
  job = None
  if cluster_resolver is None:
@@ -93,8 +94,15 @@
  # The TPU_SYSTEM device must match the device used in tpu.initialize_system
  # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
  # devices available.
  try:
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      output = _tpu_init_fn()
    context.async_wait()
  except errors.InvalidArgumentError as e:
    raise errors.NotFoundError(
        None, None,
        "TPUs not found in the cluster. Failed in initialization: "
        + str(e))

  # Clear out the eager context caches since the memory is invalid now.
  logging.info("Clearing out eager caches")

@@ -15,6 +15,10 @@ tf_class {
    name: "cluster_spec"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "connect"
    argspec: "args=[\'tpu\', \'zone\', \'project\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "get_job_name"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

@@ -15,6 +15,10 @@ tf_class {
    name: "cluster_spec"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "connect"
    argspec: "args=[\'tpu\', \'zone\', \'project\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "get_job_name"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

@@ -11,7 +11,7 @@ package(
exports_files(["LICENSE"])
tpu_module = "tpu.,distribute.tpu_strategy"
tpu_module = "tpu.,distribute.tpu_strategy,distribute.cluster_resolver.tpu,distribute.cluster_resolver.tpu_oss"
py_library(
    name = "tf_doctest_lib",