Only call initialize_tpu_system once per process.

This is an expensive call on Cloud TPU.
This opens us up to an ugly class of error in which one test leaves the TPU in an inconsistent state and causes cascading failures in later tests, but the speedup is likely worth it. It is also unlikely that any single test will saturate TPU memory.

Experimentally, this seems to reduce overall runtime from 1h20m to 43m, notably cutting //tensorflow/python/distribute:strategy_gather_test_tpu from 537s to 25s.
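
As a rough illustration of the once-per-process pattern this change introduces, here is a minimal standalone sketch; every name in it (_initialized, _cached_topology, _expensive_init, create_strategy) is an illustrative stand-in rather than an actual TensorFlow identifier, and the real change is the diff below.

# Minimal sketch of once-per-process initialization (illustrative names only).
_initialized = False
_cached_topology = None

def _expensive_init():
  # Stand-in for the expensive tpu_strategy_util.initialize_tpu_system(resolver) call.
  print("initializing TPU system (expensive)")
  return object()

def create_strategy():
  # Every test calls this, but the expensive initialization runs only once per process.
  global _initialized, _cached_topology
  if not _initialized:
    _cached_topology = _expensive_init()
    _initialized = True
  return _cached_topology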

PiperOrigin-RevId: 339570736
Change-Id: I777db704f5504c62e6c436c216892decd66776b9
Authored by Revan Sopher on 2020-10-28 17:25:22 -07:00; committed by TensorFlower Gardener
parent 20e248e158
commit e91327a5b5

@@ -41,6 +41,7 @@ from tensorflow.python.util.tf_export import tf_export
 _TF_INTERNAL_API_PREFIX = "__internal__.distribute.combinations."
 
 _did_connect_to_cluster = False
+_topology = None
 CollectiveAllReduceExtended = (
     collective_all_reduce_strategy.CollectiveAllReduceExtended)
 
@@ -76,6 +77,7 @@ def _get_tpu_strategy_creator(steps_per_run,
   def _create_tpu_strategy():
     FLAGS = flags.FLAGS  # pylint: disable=invalid-name
     global _did_connect_to_cluster
+    global _topology
 
     try:
       # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
@@ -93,16 +95,16 @@ def _get_tpu_strategy_creator(steps_per_run,
     )
 
     # Only connect once per process, rather than per test method.
-    if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
-      if not _did_connect_to_cluster:
+    if not _did_connect_to_cluster:
+      if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
         remote.connect_to_cluster(resolver)
-        _did_connect_to_cluster = True
+      _did_connect_to_cluster = True
+      _topology = tpu_strategy_util.initialize_tpu_system(resolver)
 
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
     device_assignment = None
     if use_single_core:
       device_assignment = device_assignment_lib.DeviceAssignment(
-          topology,
+          _topology,
           core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)
 
     # Steps per run is only supported in TF 1.x
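
Continuing the illustrative sketch above (not the real test harness), repeated calls within one process then share the cached result, which is what lets each test method skip re-initialization:

# Both calls reuse the cached object; _expensive_init prints only once.
t1 = create_strategy()
t2 = create_strategy()
assert t1 is t2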