diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py
index 0f2f93deac5..6b62f55b5bf 100644
--- a/tensorflow/python/tpu/tpu_strategy_util.py
+++ b/tensorflow/python/tpu/tpu_strategy_util.py
@@ -23,6 +23,7 @@ from tensorflow.python.client import session as session_lib
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
+from tensorflow.python.framework import device
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import topology
@@ -48,7 +49,15 @@ def initialize_tpu_system(cluster_resolver=None):
   Raises:
     RuntimeError: If no TPU devices found for eager execution.
   """
+  job = None
   if cluster_resolver is None:
+    # If no cluster resolver is specified, and running eagerly, execute the init
+    # ops in the current device scope.
+    if context.executing_eagerly():
+      curr_device = device.DeviceSpec.from_string(context.context().device_name)
+      if curr_device.job is not None:
+        job = "{}/replica:0/task:0".format(curr_device.job)
+
     cluster_resolver = TPUClusterResolver("")
   assert isinstance(cluster_resolver, TPUClusterResolver)
 
@@ -66,7 +75,6 @@ def initialize_tpu_system(cluster_resolver=None):
   # DistributedTPURewritePass. This pass actually adds real ops that
   # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
   # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
-  job = None
   if tpu_name not in _LOCAL_MASTERS:
     # Explicitly place the tpu.initialize_system in the first worker to
     # avoid the output node match multiple devices error.
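
The snippet below is a minimal, standalone sketch (not part of the patch) of the job string that the new eager-mode branch derives from the current device scope; the helper name and the device strings it parses are illustrative assumptions, not code from the change.

# Sketch only: mirrors the logic the patch adds to initialize_tpu_system when
# no cluster resolver is passed and eager execution is active.
from tensorflow.python.framework import device


def job_from_device_name(device_name):
  """Returns "<job>/replica:0/task:0" if the device name carries a job, else None."""
  curr_device = device.DeviceSpec.from_string(device_name)
  if curr_device.job is not None:
    return "{}/replica:0/task:0".format(curr_device.job)
  return None


# A purely local device carries no job, so no placement override is derived
# and initialization falls back to the default placement.
print(job_from_device_name("/device:CPU:0"))  # None
# A remote device scope such as tf.device("/job:worker/replica:0/task:0")
# yields an explicit placement string for the init ops.
print(job_from_device_name("/job:worker/replica:0/task:0/device:CPU:0"))
# -> worker/replica:0/task:0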