Use current device scope in initialize_tpu_system if called without a cluster resolver.

PiperOrigin-RevId: 257776495
Author: Chris Jones (2019-07-12 03:48:28 -07:00), committed by TensorFlower Gardener
parent 323564b496
commit 43dcd3dc3e

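As a sketch of the behaviour this change enables (not part of the diff): with eager execution on, calling initialize_tpu_system without a cluster resolver now picks up the active device scope. The job name "worker" and the TF 1.x eager toggle below are illustrative assumptions.

```python
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x-era toggle; eager is the default in TF 2.x

# "worker" is a hypothetical job name used only for illustration.
with tf.device("/job:worker/replica:0/task:0/device:CPU:0"):
  # With this change, the initialization ops are placed on job "worker"
  # rather than whatever the default placement would have been.
  tf.tpu.experimental.initialize_tpu_system()
```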

@@ -23,6 +23,7 @@ from tensorflow.python.client import session as session_lib
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
+from tensorflow.python.framework import device
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import topology
@@ -48,7 +49,15 @@ def initialize_tpu_system(cluster_resolver=None):
   Raises:
     RuntimeError: If no TPU devices found for eager execution.
   """
+  job = None
   if cluster_resolver is None:
+    # If no cluster resolver is specified, and running eagerly, execute the init
+    # ops in the current device scope.
+    if context.executing_eagerly():
+      curr_device = device.DeviceSpec.from_string(
+          context.context().device_name)
+      if curr_device.job is not None:
+        job = "{}/replica:0/task:0".format(curr_device.job)
     cluster_resolver = TPUClusterResolver("")
   assert isinstance(cluster_resolver, TPUClusterResolver)
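The job string built in the new branch can be exercised in isolation with the same DeviceSpec API; the device name below is a hypothetical example:

```python
from tensorflow.python.framework import device

# Parse a fully qualified device name (illustrative string).
curr_device = device.DeviceSpec.from_string(
    "/job:worker/replica:0/task:0/device:CPU:0")
assert curr_device.job == "worker"

# Same format as in the diff above: "<job>/replica:0/task:0".
job = "{}/replica:0/task:0".format(curr_device.job)
assert job == "worker/replica:0/task:0"
```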
@@ -66,7 +75,6 @@ def initialize_tpu_system(cluster_resolver=None):
   # DistributedTPURewritePass. This pass actually adds real ops that
   # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
   # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
-  job = None
   if tpu_name not in _LOCAL_MASTERS:
     # Explicitly place the tpu.initialize_system in the first worker to
     # avoid the output node match multiple devices error.
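The comment in this last hunk alludes to the defun-wrapped initialization path. Roughly, and assuming the era's internal function.defun and tpu.initialize_system APIs, the pattern looks like:

```python
from tensorflow.python.eager import function
from tensorflow.python.tpu import tpu

@function.defun
def _tpu_init_fn():
  # tpu.initialize_system only emits a placeholder op; the
  # DistributedTPURewritePass expands it into the real initialization ops,
  # which is why it cannot simply be run eagerly.
  return tpu.initialize_system()

# Invoking the defun triggers the rewrite passes and then runs the
# resulting ops, returning the serialized TPU topology.
serialized_topology = _tpu_init_fn()
```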