Cache the TPU systems that have already been initialized.

PiperOrigin-RevId: 239095653
Sourabh Bajaj 2019-03-18 17:27:39 -07:00 committed by TensorFlower Gardener
parent 9c234792c0
commit b7a1100569


@@ -34,6 +34,9 @@ from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import tf_export
 
 
+_INITIALIZED_TPU_SYSTEMS = {}
+
+
 def get_first_tpu_host_device(cluster_resolver):
   """Get the device spec for the first TPU host."""
   if context.executing_eagerly():
@@ -65,6 +68,13 @@ def initialize_tpu_system(cluster_resolver=None):
   """
   if cluster_resolver is None:
     cluster_resolver = TPUClusterResolver("")
+  assert isinstance(cluster_resolver, TPUClusterResolver)
+
+  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
+  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
+    logging.warning("TPU system %s has already been initialized. "
+                    "Reinitializing the TPU can cause previously created "
+                    "variables on TPU to be lost.", tpu_name)
 
   logging.info("Initializing the TPU system.")
 
@@ -100,4 +110,7 @@ def initialize_tpu_system(cluster_resolver=None):
       serialized_topology = sess.run(tpu.initialize_system())
 
   logging.info("Finished initializing TPU system.")
-  return topology.Topology(serialized=serialized_topology)
+  tpu_topology = topology.Topology(serialized=serialized_topology)
+  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology
+
+  return tpu_topology
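
Usage sketch (not part of the commit): assuming this function is exported as tf.tpu.experimental.initialize_tpu_system and a TPU worker is reachable, the new cache changes what happens on repeated initialization. The resolver address below is hypothetical.

import tensorflow as tf

# Hypothetical TPU worker address; any reachable TPU would do.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu="grpc://10.240.1.2:8470")

# First call: runs tpu.initialize_system() on the device and records the
# resulting Topology in _INITIALIZED_TPU_SYSTEMS under the TPU's name.
topology = tf.tpu.experimental.initialize_tpu_system(resolver)

# Second call with the same TPU name: the cache lookup hits, a warning is
# logged that reinitializing can lose variables already placed on the TPU,
# and the system is initialized again, refreshing the cached Topology.
topology = tf.tpu.experimental.initialize_tpu_system(resolver)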