Allow tpu.experimental.initialize_tpu_system to be called inside tf.function.
PiperOrigin-RevId: 336017891
Change-Id: Id99d03800c47f45893f18b9689784d8d1490ab03
parent fca5923fab
commit bbd18b8180
@@ -45,7 +45,8 @@ def initialize_tpu_system(cluster_resolver=None):
     cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
         which provides information about the TPU cluster.
 
   Returns:
-    The tf.tpu.Topology object for the topology of the TPU cluster.
+    The tf.tpu.Topology object for the topology of the TPU cluster. If called
+    inside tf.function, it returns the serialized topology object instead.
 
   Raises:
     RuntimeError: If running inside a tf.function.
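The new Returns contract in practice: a minimal usage sketch, assuming a reachable TPU; the resolver argument is a placeholder for whatever your deployment uses.

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")  # placeholder target

# Eager call: returns a tf.tpu.experimental.Topology object.
eager_topology = tf.tpu.experimental.initialize_tpu_system(resolver)

@tf.function
def init_in_graph():
  # The same call traced into a graph now returns the serialized
  # topology proto (a string tensor) instead of a Topology object.
  return tf.tpu.experimental.initialize_tpu_system(resolver)

serialized = init_in_graph()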
@@ -72,17 +73,17 @@ def initialize_tpu_system(cluster_resolver=None):
 
   logging.info("Initializing the TPU system: %s", tpu_name)
 
-  if context.executing_eagerly():
-    # This function looks as it is for the following non-intuitive reasons.
-    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
-    # DistributedTPURewritePass. This pass actually adds real ops that
-    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
-    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
-    if tpu_name not in _LOCAL_MASTERS:
-      # Explicitly place the tpu.initialize_system in the first worker to
-      # avoid the output node match multiple devices error.
-      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())
+  # This function looks as it is for the following non-intuitive reasons.
+  # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
+  # DistributedTPURewritePass. This pass actually adds real ops that
+  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
+  # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
+  if tpu_name not in _LOCAL_MASTERS:
+    # Explicitly place the tpu.initialize_system in the first worker to
+    # avoid the output node match multiple devices error.
+    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())
 
+  if context.executing_eagerly():
     @function.defun
     def _tpu_init_fn():
       # In TF1, we usually close chips when compilation fails to clear the data
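The relocated comment block captures a non-obvious mechanism: tpu.initialize_system emits only a dummy op, and DistributedTPURewritePass expands it into the real initialization ops, so the call has to be traced into a graph rather than run eagerly; on multi-worker clusters it is additionally pinned to the first worker to avoid the output-node-matches-multiple-devices error. A schematic sketch of that shape, using the internal symbols the diff itself uses and a hypothetical job name "worker":

from tensorflow.python.eager import function
from tensorflow.python.tpu import tpu

# First-worker placement string, e.g. "worker/replica:0/task:0".
job = "{}/replica:0/task:0".format("worker")

@function.defun
def _tpu_init_fn():
  # Tracing builds a graph holding the dummy op; executing the traced
  # function triggers the rewrite passes that actually initialize the TPU.
  return tpu.initialize_system(
      job=job, compilation_failure_closes_chips=False)

serialized_topology = _tpu_init_fn()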
@@ -121,8 +122,13 @@ def initialize_tpu_system(cluster_resolver=None):
     with session_lib.Session(config=session_config, target=master) as sess:
       serialized_topology = sess.run(tpu.initialize_system())
   else:
-    raise RuntimeError("initialize_tpu_system is not supported within "
-                       "tf.functions.")
+    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
+      serialized_topology = tpu.initialize_system(
+          job=job, compilation_failure_closes_chips=False)
+      # If initialize_tpu_system is called inside tf.function, we only return
+      # the serialized topology object as the tf.tpu.Topology object has to be
+      # constructed in eager mode.
+      return serialized_topology
 
   logging.info("Finished initializing TPU system.")
   tpu_topology = topology.Topology(serialized=serialized_topology)
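Because the tf.tpu.Topology object has to be constructed in eager mode, a caller that initializes from inside tf.function finishes the round trip eagerly. A hedged sketch of that two-step pattern, reusing the resolver from the earlier sketch:

from tensorflow.python.tpu import topology as topology_lib

@tf.function
def init_tpu():
  # Now returns the serialized topology string tensor instead of
  # raising RuntimeError.
  return tf.tpu.experimental.initialize_tpu_system(resolver)

serialized = init_tpu()
# Rebuild the Topology object eagerly, mirroring the eager path's
# topology.Topology(serialized=...) call above.
tpu_topology = topology_lib.Topology(serialized=serialized.numpy())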