Use current device scope in initialize_tpu_system
if called without a cluster resolver.
PiperOrigin-RevId: 257776495
This commit is contained in:
parent
323564b496
commit
43dcd3dc3e
@ -23,6 +23,7 @@ from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import device
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import topology
|
||||
@ -48,7 +49,15 @@ def initialize_tpu_system(cluster_resolver=None):
|
||||
Raises:
|
||||
RuntimeError: If no TPU devices found for eager execution.
|
||||
"""
|
||||
job = None
|
||||
if cluster_resolver is None:
|
||||
# If no cluster resolver is specified, and running eagerly, execute the init
|
||||
# ops in the current device scope.
|
||||
if context.executing_eagerly():
|
||||
curr_device = device.DeviceSpec.from_string(context.context().device_name)
|
||||
if curr_device.job is not None:
|
||||
job = "{}/replica:0/task:0".format(curr_device.job)
|
||||
|
||||
cluster_resolver = TPUClusterResolver("")
|
||||
assert isinstance(cluster_resolver, TPUClusterResolver)
|
||||
|
||||
@ -66,7 +75,6 @@ def initialize_tpu_system(cluster_resolver=None):
|
||||
# DistributedTPURewritePass. This pass actually adds real ops that
|
||||
# initialize the TPU system. Thus, we can't simply run tpu.initialize_system
|
||||
# eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
|
||||
job = None
|
||||
if tpu_name not in _LOCAL_MASTERS:
|
||||
# Explicitly place the tpu.initialize_system in the first worker to
|
||||
# avoid the "output node match multiple devices" error.
|
||||
|
Loading…
Reference in New Issue
Block a user