Allow tpu.experimental.initialize_tpu_system to be called inside tf.function.

PiperOrigin-RevId: 336017891
Change-Id: Id99d03800c47f45893f18b9689784d8d1490ab03
Author: Ruoxin Sang
Committed by: TensorFlower Gardener
Date: 2020-10-07 22:13:55 -07:00
Parent: fca5923fab
Commit: bbd18b8180
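
A rough usage sketch of what this change enables (the TPU name, resolver setup, and function names below are assumptions for illustration, not code from this commit):

    import tensorflow as tf

    # Hypothetical resolver; "my-tpu" is an assumed TPU name.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")

    # Eager call: returns a Topology object describing the TPU cluster.
    eager_topology = tf.tpu.experimental.initialize_tpu_system(resolver)

    # With this change, the same call is also allowed inside tf.function,
    # where it returns the serialized topology tensor instead.
    @tf.function
    def init_in_graph():
      return tf.tpu.experimental.initialize_tpu_system(resolver)

    serialized = init_in_graph()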


@@ -45,7 +45,8 @@ def initialize_tpu_system(cluster_resolver=None):
     cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
         which provides information about the TPU cluster.
   Returns:
-    The tf.tpu.Topology object for the topology of the TPU cluster.
+    The tf.tpu.Topology object for the topology of the TPU cluster. If called
+    inside tf.function, it returns the serialized topology object instead.
 
   Raises:
     RuntimeError: If running inside a tf.function.
@@ -72,17 +73,17 @@ def initialize_tpu_system(cluster_resolver=None):
   logging.info("Initializing the TPU system: %s", tpu_name)
 
-  if context.executing_eagerly():
-    # This function looks as it is for the following non-intuitive reasons.
-    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
-    # DistributedTPURewritePass. This pass actually adds real ops that
-    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
-    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
-    if tpu_name not in _LOCAL_MASTERS:
-      # Explicitly place the tpu.initialize_system in the first worker to
-      # avoid the output node match multiple devices error.
-      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())
+  # This function looks as it is for the following non-intuitive reasons.
+  # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
+  # DistributedTPURewritePass. This pass actually adds real ops that
+  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
+  # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
+  if tpu_name not in _LOCAL_MASTERS:
+    # Explicitly place the tpu.initialize_system in the first worker to
+    # avoid the output node match multiple devices error.
+    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())
 
+  if context.executing_eagerly():
     @function.defun
     def _tpu_init_fn():
       # In TF1, we usually close chips when compilation fails to clear the data
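
The "non-intuitive" comment above refers to the pattern sketched below: tpu.initialize_system only emits a placeholder op, so it is built inside a defun and DistributedTPURewritePass expands it into the real initialization ops when the function runs. A minimal sketch, assuming the internal imports and a local job value:

    from tensorflow.python.eager import function
    from tensorflow.python.tpu import tpu

    job = None  # assumed local setup; a remote cluster would use e.g. "worker/replica:0/task:0"

    @function.defun
    def _tpu_init_fn():
      # The op created here is a placeholder that DistributedTPURewritePass
      # replaces with the ops that actually initialize the TPU system.
      return tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)

    output = _tpu_init_fn()  # triggers the rewrite and runs the real init ops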
@@ -121,8 +122,13 @@ def initialize_tpu_system(cluster_resolver=None):
     with session_lib.Session(config=session_config, target=master) as sess:
       serialized_topology = sess.run(tpu.initialize_system())
   else:
-    raise RuntimeError("initialize_tpu_system is not supported within "
-                       "tf.functions.")
+    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
+      serialized_topology = tpu.initialize_system(
+          job=job, compilation_failure_closes_chips=False)
+      # If initialize_tpu_system is called inside tf.function, we only return
+      # the serialized topology object as the tf.tpu.Topology object has to be
+      # constructed in eager mode.
+      return serialized_topology
 
   logging.info("Finished initializing TPU system.")
   tpu_topology = topology.Topology(serialized=serialized_topology)
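
As the added comment notes, the Topology object can only be constructed in eager mode, so a caller that runs the initializer inside tf.function converts the serialized result afterwards. A hedged sketch (the .numpy() conversion and the init_in_graph helper from the earlier sketch are assumptions):

    # Back in eager mode, turn the serialized topology into a Topology object.
    serialized = init_in_graph()  # tf.function from the earlier sketch
    tpu_topology = tf.tpu.experimental.Topology(serialized=serialized.numpy())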