From bbd18b8180a7e63a082ffde52ba3d7307f1740c7 Mon Sep 17 00:00:00 2001
From: Ruoxin Sang
Date: Wed, 7 Oct 2020 22:13:55 -0700
Subject: [PATCH] Allow `tpu.experimental.initialize_tpu_system` to be called
 inside tf.function.

PiperOrigin-RevId: 336017891
Change-Id: Id99d03800c47f45893f18b9689784d8d1490ab03
---
 tensorflow/python/tpu/tpu_strategy_util.py | 32 +++++++++++++---------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py
index d4ba15ee2db..b52d22c60dc 100644
--- a/tensorflow/python/tpu/tpu_strategy_util.py
+++ b/tensorflow/python/tpu/tpu_strategy_util.py
@@ -45,7 +45,8 @@ def initialize_tpu_system(cluster_resolver=None):
     cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
         which provides information about the TPU cluster.
   Returns:
-    The tf.tpu.Topology object for the topology of the TPU cluster.
+    The tf.tpu.Topology object for the topology of the TPU cluster. If called
+    inside tf.function, it returns the serialized topology object instead.
 
   Raises:
     RuntimeError: If running inside a tf.function.
@@ -72,17 +73,17 @@ def initialize_tpu_system(cluster_resolver=None):
 
   logging.info("Initializing the TPU system: %s", tpu_name)
 
-  if context.executing_eagerly():
-    # This function looks as it is for the following non-intuitive reasons.
-    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
-    # DistributedTPURewritePass. This pass actually adds real ops that
-    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
-    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
-    if tpu_name not in _LOCAL_MASTERS:
-      # Explicitly place the tpu.initialize_system in the first worker to
-      # avoid the output node match multiple devices error.
-      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())
+  # This function looks as it is for the following non-intuitive reasons.
+  # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
+  # DistributedTPURewritePass. This pass actually adds real ops that
+  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
+  # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
+  if tpu_name not in _LOCAL_MASTERS:
+    # Explicitly place the tpu.initialize_system in the first worker to
+    # avoid the output node match multiple devices error.
+    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())
 
+  if context.executing_eagerly():
     @function.defun
     def _tpu_init_fn():
       # In TF1, we usually close chips when compilation fails to clear the data
@@ -121,8 +122,13 @@ def initialize_tpu_system(cluster_resolver=None):
     with session_lib.Session(config=session_config, target=master) as sess:
       serialized_topology = sess.run(tpu.initialize_system())
   else:
-    raise RuntimeError("initialize_tpu_system is not supported within "
-                       "tf.functions.")
+    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
+      serialized_topology = tpu.initialize_system(
+          job=job, compilation_failure_closes_chips=False)
+    # If initialize_tpu_system is called inside tf.function, we only return
+    # the serialized topology object as the tf.tpu.Topology object has to be
+    # constructed in eager mode.
+    return serialized_topology
 
   logging.info("Finished initializing TPU system.")
   tpu_topology = topology.Topology(serialized=serialized_topology)
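
A minimal usage sketch of the pattern this patch enables, assuming a resolver
for a reachable TPU (the `tpu="local"` argument and the `init_tpu` name below
are placeholders for illustration, not part of the commit):

    import tensorflow as tf

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")

    @tf.function
    def init_tpu():
      # Per this patch, when called inside tf.function the call returns the
      # serialized topology (a string tensor) rather than a tf.tpu.Topology
      # object, since the latter can only be constructed in eager mode.
      return tf.tpu.experimental.initialize_tpu_system(resolver)

    serialized_topology = init_tpu()

Calling the same function eagerly still returns the constructed
tf.tpu.Topology object, as before.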