Cache the TPU systems that have already been initialized.
PiperOrigin-RevId: 239095653
This commit is contained in:
parent
9c234792c0
commit
b7a1100569
@ -34,6 +34,9 @@ from tensorflow.python.util import compat
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
_INITIALIZED_TPU_SYSTEMS = {}
|
||||
|
||||
|
||||
def get_first_tpu_host_device(cluster_resolver):
|
||||
"""Get the device spec for the first TPU host."""
|
||||
if context.executing_eagerly():
|
||||
@ -65,6 +68,13 @@ def initialize_tpu_system(cluster_resolver=None):
|
||||
"""
|
||||
if cluster_resolver is None:
|
||||
cluster_resolver = TPUClusterResolver("")
|
||||
assert isinstance(cluster_resolver, TPUClusterResolver)
|
||||
|
||||
tpu_name = compat.as_text(cluster_resolver._tpu) # pylint: disable=protected-access
|
||||
if tpu_name in _INITIALIZED_TPU_SYSTEMS:
|
||||
logging.warning("TPU system %s has already been initialized. "
|
||||
"Reinitializing the TPU can cause previously created "
|
||||
"variables on TPU to be lost.")
|
||||
|
||||
logging.info("Initializing the TPU system.")
|
||||
|
||||
@ -100,4 +110,7 @@ def initialize_tpu_system(cluster_resolver=None):
|
||||
serialized_topology = sess.run(tpu.initialize_system())
|
||||
|
||||
logging.info("Finished initializing TPU system.")
|
||||
return topology.Topology(serialized=serialized_topology)
|
||||
tpu_topology = topology.Topology(serialized=serialized_topology)
|
||||
_INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology
|
||||
|
||||
return tpu_topology
|
||||
|
Loading…
Reference in New Issue
Block a user