Add a `TPUClusterResolver.connect` API to simplify TPU initialization.
RELNOTES=Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization. PiperOrigin-RevId: 317439811 Change-Id: I2f1a944f3c440356b21da27a72855c969f1c3b3b
This commit is contained in:
parent
27cb9aa834
commit
ae78d9c315
|
@ -52,16 +52,59 @@ class TPUClusterResolver(cluster_resolver.ClusterResolver):
|
|||
"""Cluster Resolver for Google Cloud TPUs.
|
||||
|
||||
This is an implementation of cluster resolvers for the Google Cloud TPU
|
||||
service. As Cloud TPUs are in alpha, you will need to specify a API definition
|
||||
file for this to consume, in addition to a list of Cloud TPUs in your Google
|
||||
Cloud Platform project.
|
||||
service.
|
||||
|
||||
TPUClusterResolver supports the following distinct environments:
|
||||
Google Compute Engine
|
||||
Google Kubernetes Engine
|
||||
Google internal
|
||||
|
||||
It can be passed into `tf.distribute.TPUStrategy` to support TF2 training on
|
||||
Cloud TPUs.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def connect(tpu=None,
|
||||
zone=None,
|
||||
project=None):
|
||||
"""Initializes TPU and returns a TPUClusterResolver.
|
||||
|
||||
This API will connect to remote TPU cluster and initialize the TPU
|
||||
hardwares. Example usage:
|
||||
|
||||
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
|
||||
... tpu='')
|
||||
|
||||
It can be viewed as convenient wrapper of the following code:
|
||||
|
||||
>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
|
||||
>>> tf.config.experimental_connect_to_cluster(resolver)
|
||||
>>> tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
|
||||
Args:
|
||||
tpu: A string corresponding to the TPU to use. It can be the TPU name or
|
||||
TPU worker gRPC address. If not set, it will try automatically resolve
|
||||
the TPU address on Cloud TPUs.
|
||||
zone: Zone where the TPUs are located. If omitted or empty, we will assume
|
||||
that the zone of the TPU is the same as the zone of the GCE VM, which we
|
||||
will try to discover from the GCE metadata service.
|
||||
project: Name of the GCP project containing Cloud TPUs. If omitted or
|
||||
empty, we will try to discover the project name of the GCE VM from the
|
||||
GCE metadata service.
|
||||
|
||||
Returns:
|
||||
An instance of TPUClusterResolver object.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no TPU devices found in eager mode.
|
||||
"""
|
||||
resolver = TPUClusterResolver(tpu, zone, project)
|
||||
from tensorflow.python.eager import remote # pylint: disable=g-import-not-at-top
|
||||
remote.connect_to_cluster(resolver)
|
||||
from tensorflow.python.tpu import tpu_strategy_util # pylint: disable=g-import-not-at-top
|
||||
tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
return resolver
|
||||
|
||||
@staticmethod
|
||||
def _get_device_dict_and_cores(devices):
|
||||
"""Returns a dict of hosts to cores and total cores given devices names.
|
||||
|
@ -110,12 +153,9 @@ class TPUClusterResolver(cluster_resolver.ClusterResolver):
|
|||
for the IP addresses and ports of each Cloud TPU listed.
|
||||
|
||||
Args:
|
||||
tpu: A string corresponding to the TPU to use. If the string is an empty
|
||||
string, the string 'local', or a string that begins with 'grpc://', then
|
||||
it is assumed to not correspond with a Cloud TPU and will instead be
|
||||
passed as the session master and no ClusterSpec propagation will be
|
||||
done. In the future, this may also support a list of strings when
|
||||
multiple Cloud TPUs are used.
|
||||
tpu: A string corresponding to the TPU to use. It can be the TPU name or
|
||||
TPU worker gRPC address. If not set, it will try automatically resolve
|
||||
the TPU address on Cloud TPUs.
|
||||
zone: Zone where the TPUs are located. If omitted or empty, we will assume
|
||||
that the zone of the TPU is the same as the zone of the GCE VM, which we
|
||||
will try to discover from the GCE metadata service.
|
||||
|
|
|
@ -338,7 +338,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
|||
if context.executing_eagerly():
|
||||
# In async remote eager, we want to sync the exectors before exiting the
|
||||
# program.
|
||||
atexit.register(context.async_wait)
|
||||
def async_wait():
|
||||
if context.context()._context_handle is not None: # pylint: disable=protected-access
|
||||
context.async_wait()
|
||||
atexit.register(async_wait)
|
||||
|
||||
# TODO(bfontain): Remove once a proper dataset API exists for prefetching
|
||||
# a dataset to multiple devices exists.
|
||||
|
|
|
@ -24,6 +24,7 @@ from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import T
|
|||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import device
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import topology
|
||||
|
@ -47,8 +48,8 @@ def initialize_tpu_system(cluster_resolver=None):
|
|||
The tf.tpu.Topology object for the topology of the TPU cluster.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no TPU devices found for eager execution or if run in a
|
||||
tf.function.
|
||||
RuntimeError: If running inside a tf.function.
|
||||
NotFoundError: If no TPU devices found in eager mode.
|
||||
"""
|
||||
job = None
|
||||
if cluster_resolver is None:
|
||||
|
@ -93,8 +94,15 @@ def initialize_tpu_system(cluster_resolver=None):
|
|||
# The TPU_SYSTEM device must match the device used in tpu.initialize_system
|
||||
# exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
|
||||
# devices available.
|
||||
try:
|
||||
with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access
|
||||
output = _tpu_init_fn()
|
||||
context.async_wait()
|
||||
except errors.InvalidArgumentError as e:
|
||||
raise errors.NotFoundError(
|
||||
None, None,
|
||||
"TPUs not found in the cluster. Failed in initialization: "
|
||||
+ str(e))
|
||||
|
||||
# Clear out the eager context caches since the memory is invalid now.
|
||||
logging.info("Clearing out eager caches")
|
||||
|
|
|
@ -15,6 +15,10 @@ tf_class {
|
|||
name: "cluster_spec"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "connect"
|
||||
argspec: "args=[\'tpu\', \'zone\', \'project\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_job_name"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
|
|
@ -15,6 +15,10 @@ tf_class {
|
|||
name: "cluster_spec"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "connect"
|
||||
argspec: "args=[\'tpu\', \'zone\', \'project\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_job_name"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
|
|
|
@ -11,7 +11,7 @@ package(
|
|||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
tpu_module = "tpu.,distribute.tpu_strategy"
|
||||
tpu_module = "tpu.,distribute.tpu_strategy,distribute.cluster_resolver.tpu,distribute.cluster_resolver.tpu_oss"
|
||||
|
||||
py_library(
|
||||
name = "tf_doctest_lib",
|
||||
|
|
Loading…
Reference in New Issue