From 43dcd3dc3ee4b090832455acf43e8dd483a6117b Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Fri, 12 Jul 2019 03:48:28 -0700 Subject: [PATCH] Use current device scope in `initialize_tpu_system` if called without a cluster resolver. PiperOrigin-RevId: 257776495 --- tensorflow/python/tpu/tpu_strategy_util.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py index 0f2f93deac5..6b62f55b5bf 100644 --- a/tensorflow/python/tpu/tpu_strategy_util.py +++ b/tensorflow/python/tpu/tpu_strategy_util.py @@ -23,6 +23,7 @@ from tensorflow.python.client import session as session_lib from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver from tensorflow.python.eager import context from tensorflow.python.eager import function +from tensorflow.python.framework import device from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.tpu import topology @@ -48,7 +49,15 @@ def initialize_tpu_system(cluster_resolver=None): Raises: RuntimeError: If no TPU devices found for eager execution. """ + job = None if cluster_resolver is None: + # If no cluster resolver is specified, and running eagerly, execute the init + # ops in the current device scope. + if context.executing_eagerly(): + curr_device = device.DeviceSpec.from_string(context.context().device_name) + if curr_device.job is not None: + job = "{}/replica:0/task:0".format(curr_device.job) + cluster_resolver = TPUClusterResolver("") assert isinstance(cluster_resolver, TPUClusterResolver) @@ -66,7 +75,6 @@ def initialize_tpu_system(cluster_resolver=None): # DistributedTPURewritePass. This pass actually adds real ops that # initialize the TPU system. Thus, we can't simply run tpu.initialize_system # eagerly. We need to wrap it in defun and trigger the rewrite passes on it. - job = None if tpu_name not in _LOCAL_MASTERS: # Explicitly place the tpu.initialize_system in the first worker to # avoid the output node match multiple devices error.