Use current device scope in initialize_tpu_system if called without a cluster resolver.

PiperOrigin-RevId: 257776495
Author: Chris Jones (2019-07-12 03:48:28 -07:00); committed by: TensorFlower Gardener
parent 323564b496
commit 43dcd3dc3e


@@ -23,6 +23,7 @@ from tensorflow.python.client import session as session_lib
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
+from tensorflow.python.framework import device
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import topology
@@ -48,7 +49,15 @@ def initialize_tpu_system(cluster_resolver=None):
   Raises:
     RuntimeError: If no TPU devices found for eager execution.
   """
+  job = None
   if cluster_resolver is None:
+    # If no cluster resolver is specified, and running eagerly, execute the init
+    # ops in the current device scope.
+    if context.executing_eagerly():
+      curr_device = device.DeviceSpec.from_string(context.context().device_name)
+      if curr_device.job is not None:
+        job = "{}/replica:0/task:0".format(curr_device.job)
     cluster_resolver = TPUClusterResolver("")
   assert isinstance(cluster_resolver, TPUClusterResolver)
@@ -66,7 +75,6 @@ def initialize_tpu_system(cluster_resolver=None):
   # DistributedTPURewritePass. This pass actually adds real ops that
   # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
   # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
-  job = None
   if tpu_name not in _LOCAL_MASTERS:
     # Explicitly place the tpu.initialize_system in the first worker to
     # avoid the output node match multiple devices error.
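
For context, a minimal usage sketch of the behavior this change enables, assuming eager execution and a reachable remote TPU worker. The gRPC address and job name below are placeholders, not part of this commit:

    # Sketch only: assumes TF >= 1.14 eager mode; the address and job name
    # are hypothetical and must be adapted to your deployment.
    import tensorflow as tf

    # Connect the eager runtime to a remote host. Eager ops placed after this
    # call carry a device scope whose job is "tpu_worker".
    tf.config.experimental_connect_to_host(
        "grpc://10.0.0.1:8470", job_name="tpu_worker")

    # With this change, no cluster resolver argument is needed: the init ops
    # are placed using the current device scope
    # ("/job:tpu_worker/replica:0/task:0") instead of a local master.
    tf.tpu.experimental.initialize_tpu_system()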