Add a `TPUClusterResolver.connect` API to simplify TPU initialization.
RELNOTES=Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
PiperOrigin-RevId: 317439811
Change-Id: I2f1a944f3c440356b21da27a72855c969f1c3b3b

commit ae78d9c315
parent 27cb9aa834
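For context, a minimal sketch of how the new one-call API is meant to be used for TF2 training, assuming a release where `tf.distribute.TPUStrategy` accepts the resolver (the Keras model below is a hypothetical placeholder):

    import tensorflow as tf

    # Connect to the TPU cluster and initialize the TPU system in a single call.
    # tpu='' lets the resolver auto-discover the TPU address on Cloud TPU.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu='')
    strategy = tf.distribute.TPUStrategy(resolver)

    with strategy.scope():
        # Variables created under the scope are placed on the TPU replicas.
        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        model.compile(optimizer='sgd', loss='mse')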
@@ -52,16 +52,59 @@ class TPUClusterResolver(cluster_resolver.ClusterResolver):
   """Cluster Resolver for Google Cloud TPUs.
 
   This is an implementation of cluster resolvers for the Google Cloud TPU
-  service. As Cloud TPUs are in alpha, you will need to specify a API definition
-  file for this to consume, in addition to a list of Cloud TPUs in your Google
-  Cloud Platform project.
+  service.
 
   TPUClusterResolver supports the following distinct environments:
   Google Compute Engine
   Google Kubernetes Engine
   Google internal
+
+  It can be passed into `tf.distribute.TPUStrategy` to support TF2 training on
+  Cloud TPUs.
   """
 
+  @staticmethod
+  def connect(tpu=None,
+              zone=None,
+              project=None):
+    """Initializes TPU and returns a TPUClusterResolver.
+
+    This API will connect to the remote TPU cluster and initialize the TPU
+    hardware. Example usage:
+
+    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
+    ...     tpu='')
+
+    It can be viewed as a convenient wrapper for the following code:
+
+    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+    >>> tf.config.experimental_connect_to_cluster(resolver)
+    >>> tf.tpu.experimental.initialize_tpu_system(resolver)
+
+    Args:
+      tpu: A string corresponding to the TPU to use. It can be the TPU name or
+        the TPU worker gRPC address. If not set, it will try to automatically
+        resolve the TPU address on Cloud TPUs.
+      zone: Zone where the TPUs are located. If omitted or empty, we will assume
+        that the zone of the TPU is the same as the zone of the GCE VM, which we
+        will try to discover from the GCE metadata service.
+      project: Name of the GCP project containing Cloud TPUs. If omitted or
+        empty, we will try to discover the project name of the GCE VM from the
+        GCE metadata service.
+
+    Returns:
+      An instance of TPUClusterResolver.
+
+    Raises:
+      NotFoundError: If no TPU devices found in eager mode.
+    """
+    resolver = TPUClusterResolver(tpu, zone, project)
+    from tensorflow.python.eager import remote  # pylint: disable=g-import-not-at-top
+    remote.connect_to_cluster(resolver)
+    from tensorflow.python.tpu import tpu_strategy_util  # pylint: disable=g-import-not-at-top
+    tpu_strategy_util.initialize_tpu_system(resolver)
+    return resolver
+
   @staticmethod
   def _get_device_dict_and_cores(devices):
     """Returns a dict of hosts to cores and total cores given devices names.
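Because `connect` propagates the `NotFoundError` documented in its Raises section, callers can fall back gracefully when no TPU is reachable. A hedged sketch of that pattern (the fallback to the default strategy is illustrative, not part of this change):

    import tensorflow as tf

    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu='')
        strategy = tf.distribute.TPUStrategy(resolver)
    except tf.errors.NotFoundError:
        # No TPU devices were found in the cluster; run on the default
        # (CPU/GPU) strategy instead.
        strategy = tf.distribute.get_strategy()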
@@ -110,12 +153,9 @@ class TPUClusterResolver(cluster_resolver.ClusterResolver):
     for the IP addresses and ports of each Cloud TPU listed.
 
     Args:
-      tpu: A string corresponding to the TPU to use. If the string is an empty
-        string, the string 'local', or a string that begins with 'grpc://', then
-        it is assumed to not correspond with a Cloud TPU and will instead be
-        passed as the session master and no ClusterSpec propagation will be
-        done. In the future, this may also support a list of strings when
-        multiple Cloud TPUs are used.
+      tpu: A string corresponding to the TPU to use. It can be the TPU name or
+        the TPU worker gRPC address. If not set, it will try to automatically
+        resolve the TPU address on Cloud TPUs.
       zone: Zone where the TPUs are located. If omitted or empty, we will assume
         that the zone of the TPU is the same as the zone of the GCE VM, which we
         will try to discover from the GCE metadata service.
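To illustrate the two accepted forms of the updated `tpu` argument, a short sketch (the TPU name, zone, project, and gRPC address are placeholders):

    import tensorflow as tf

    # By TPU name, resolved through the Cloud TPU API; zone and project may be
    # omitted when they match the GCE VM's metadata.
    resolver_by_name = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu='my-tpu', zone='us-central1-b', project='my-gcp-project')

    # By the TPU worker's gRPC address, bypassing name resolution entirely.
    resolver_by_address = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu='grpc://10.240.1.2:8470')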
@@ -338,7 +338,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     if context.executing_eagerly():
       # In async remote eager, we want to sync the executors before exiting the
       # program.
-      atexit.register(context.async_wait)
+      def async_wait():
+        if context.context()._context_handle is not None:  # pylint: disable=protected-access
+          context.async_wait()
+      atexit.register(async_wait)
 
     # TODO(bfontain): Remove once a proper dataset API for prefetching a
     # dataset to multiple devices exists.
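The guard above makes the atexit hook a no-op when the eager context handle was never created (or is already gone), so constructing the strategy cannot force a context to be built during interpreter shutdown. A generic sketch of the same guard pattern, using a hypothetical runtime object in place of the eager context:

    import atexit

    class _Runtime:
        """Hypothetical stand-in for the lazily created eager context."""

        def __init__(self):
            self.handle = None  # populated only once real work has run

        def flush_pending_work(self):
            print('flushing pending async work')

    _runtime = _Runtime()

    def _flush_at_exit():
        # Only flush if the runtime was actually initialized; otherwise the
        # hook does nothing and cannot fail during shutdown.
        if _runtime.handle is not None:
            _runtime.flush_pending_work()

    atexit.register(_flush_at_exit)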
@@ -24,6 +24,7 @@ from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import T
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
 from tensorflow.python.framework import device
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import topology
@@ -47,8 +48,8 @@ def initialize_tpu_system(cluster_resolver=None):
     The tf.tpu.Topology object for the topology of the TPU cluster.
 
   Raises:
-    RuntimeError: If no TPU devices found for eager execution or if run in a
-      tf.function.
+    RuntimeError: If running inside a tf.function.
+    NotFoundError: If no TPU devices found in eager mode.
   """
   job = None
   if cluster_resolver is None:
@@ -93,8 +94,15 @@ def initialize_tpu_system(cluster_resolver=None):
   # The TPU_SYSTEM device must match the device used in tpu.initialize_system
   # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
   # devices available.
-  with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
-    output = _tpu_init_fn()
+  try:
+    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
+      output = _tpu_init_fn()
+      context.async_wait()
+  except errors.InvalidArgumentError as e:
+    raise errors.NotFoundError(
+        None, None,
+        "TPUs not found in the cluster. Failed in initialization: "
+        + str(e))
 
   # Clear out the eager context caches since the memory is invalid now.
   logging.info("Clearing out eager caches")
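For reference, TensorFlow's `OpError` subclasses such as `errors.NotFoundError` take `(node_def, op, message)`, so the two `None` arguments above simply mean the error is not tied to a specific op or node. A small sketch of the same re-raise pattern in isolation (the `init_fn` callable is hypothetical):

    from tensorflow.python.framework import errors

    def run_init_or_not_found(init_fn):
        # Translate a low-level InvalidArgumentError raised during
        # initialization into a clearer NotFoundError for callers.
        try:
            return init_fn()
        except errors.InvalidArgumentError as e:
            raise errors.NotFoundError(
                None, None,  # node_def and op: not associated with a specific op
                'TPUs not found in the cluster. Failed in initialization: ' + str(e))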
@@ -15,6 +15,10 @@ tf_class {
     name: "cluster_spec"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "connect"
+    argspec: "args=[\'tpu\', \'zone\', \'project\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
   member_method {
     name: "get_job_name"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
@@ -15,6 +15,10 @@ tf_class {
     name: "cluster_spec"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "connect"
+    argspec: "args=[\'tpu\', \'zone\', \'project\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
   member_method {
     name: "get_job_name"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
@@ -11,7 +11,7 @@ package(
 
 exports_files(["LICENSE"])
 
-tpu_module = "tpu.,distribute.tpu_strategy"
+tpu_module = "tpu.,distribute.tpu_strategy,distribute.cluster_resolver.tpu,distribute.cluster_resolver.tpu_oss"
 
 py_library(
     name = "tf_doctest_lib",