Use current device scope in initialize_tpu_system if called without a cluster resolver.
PiperOrigin-RevId: 257776495
commit 43dcd3dc3e (parent 323564b496)
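
A hedged usage sketch of the behavior this commit adds (not part of the commit itself): when executing eagerly with no cluster resolver passed, initialize_tpu_system now picks up the job from the surrounding device scope. The job name "worker" below is hypothetical, and the function is assumed to be reachable through its public alias tf.tpu.experimental.initialize_tpu_system.

import tensorflow as tf

# Hypothetical job name "worker": with this change, the TPU init ops are
# placed on /job:worker/replica:0/task:0 rather than the default device,
# because the eager context's current device name carries that job.
with tf.device("/job:worker/replica:0/task:0/device:CPU:0"):
    tf.tpu.experimental.initialize_tpu_system()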
@@ -23,6 +23,7 @@ from tensorflow.python.client import session as session_lib
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
+from tensorflow.python.framework import device
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import topology
@@ -48,7 +49,15 @@ def initialize_tpu_system(cluster_resolver=None):
   Raises:
     RuntimeError: If no TPU devices found for eager execution.
   """
+  job = None
   if cluster_resolver is None:
+    # If no cluster resolver is specified, and running eagerly, execute the init
+    # ops in the current device scope.
+    if context.executing_eagerly():
+      curr_device = device.DeviceSpec.from_string(context.context().device_name)
+      if curr_device.job is not None:
+        job = "{}/replica:0/task:0".format(curr_device.job)
+
     cluster_resolver = TPUClusterResolver("")
   assert isinstance(cluster_resolver, TPUClusterResolver)
 
@@ -66,7 +75,6 @@ def initialize_tpu_system(cluster_resolver=None):
   # DistributedTPURewritePass. This pass actually adds real ops that
   # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
   # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
-  job = None
   if tpu_name not in _LOCAL_MASTERS:
     # Explicitly place the tpu.initialize_system in the first worker to
     # avoid the output node match multiple devices error.
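
For reference, a minimal sketch of the parsing step the new code performs on context.context().device_name. Assumptions: the public tf.DeviceSpec API mirrors the internal tensorflow.python.framework.device module used in the diff, and the device string below is illustrative.

import tensorflow as tf

# Parse a full eager device name, as the new code does with
# context.context().device_name.
spec = tf.DeviceSpec.from_string("/job:worker/replica:0/task:0/device:CPU:0")
print(spec.job)  # "worker"

# Mirror of the new fallback: pin the init ops to the scope's job when one
# is present; job stays None otherwise (hence the hoisted "job = None").
if spec.job is not None:
    job = "{}/replica:0/task:0".format(spec.job)
    print(job)  # "worker/replica:0/task:0"

The third hunk follows from this: the "job = None" initializer moves from just before the _LOCAL_MASTERS check to the top of the function, so the eager, resolver-less path can set it before the TPU master is resolved.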