From e91327a5b5676ac24cd0e1883518a5fe22567bdb Mon Sep 17 00:00:00 2001
From: Revan Sopher
Date: Wed, 28 Oct 2020 17:25:22 -0700
Subject: [PATCH] Only call initialize_tpu_system once per process.

This is an expensive call on Cloud TPU.

This opens us up to a very ugly class of error if one test leaves the TPU in
an inconsistent state, causing cascading failures, but is likely worth the
speedup. It is unlikely that any one test will saturate TPU memory.

Experimentally, this seems to reduce runtime from 1h20m to 43m, notably
reducing //tensorflow/python/distribute:strategy_gather_test_tpu from 537s
to 25s.

PiperOrigin-RevId: 339570736
Change-Id: I777db704f5504c62e6c436c216892decd66776b9
---
 tensorflow/python/distribute/strategy_combinations.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py
index a9a66bc491c..1c014dd6edf 100644
--- a/tensorflow/python/distribute/strategy_combinations.py
+++ b/tensorflow/python/distribute/strategy_combinations.py
@@ -41,6 +41,7 @@ from tensorflow.python.util.tf_export import tf_export
 _TF_INTERNAL_API_PREFIX = "__internal__.distribute.combinations."
 
 _did_connect_to_cluster = False
+_topology = None
 CollectiveAllReduceExtended = (
     collective_all_reduce_strategy.CollectiveAllReduceExtended)
 
@@ -76,6 +77,7 @@ def _get_tpu_strategy_creator(steps_per_run,
   def _create_tpu_strategy():
     FLAGS = flags.FLAGS  # pylint: disable=invalid-name
     global _did_connect_to_cluster
+    global _topology
 
     try:
       # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
@@ -93,16 +95,16 @@ def _get_tpu_strategy_creator(steps_per_run,
       )
 
     # Only connect once per process, rather than per test method.
-    if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
-      if not _did_connect_to_cluster:
+    if not _did_connect_to_cluster:
+      if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
         remote.connect_to_cluster(resolver)
         _did_connect_to_cluster = True
+      _topology = tpu_strategy_util.initialize_tpu_system(resolver)
 
-    topology = tpu_strategy_util.initialize_tpu_system(resolver)
     device_assignment = None
     if use_single_core:
       device_assignment = device_assignment_lib.DeviceAssignment(
-          topology,
+          _topology,
          core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)
 
     # Steps per run is only supported in TF 1.x
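
For reference, the change boils down to caching the result of an expensive
one-time call in module-level globals so that later test cases in the same
process reuse it. Below is a minimal, self-contained sketch of that pattern;
the names (expensive_setup, create_strategy, _setup_result) are hypothetical
stand-ins and not the TensorFlow API:

  # Minimal sketch of the once-per-process initialization pattern, assuming a
  # stand-in expensive_setup() in place of tpu_strategy_util.initialize_tpu_system.
  _setup_result = None

  def expensive_setup():
    # Pretend this takes minutes on a real Cloud TPU.
    return {"initialized": True}

  def create_strategy():
    global _setup_result
    if _setup_result is None:
      # The first call in the process pays the cost; later calls reuse the
      # cached result, mirroring the _did_connect_to_cluster/_topology globals.
      _setup_result = expensive_setup()
    return _setup_result

Each test that requests a TPU strategy then goes through create_strategy()
and shares the cached setup result rather than reinitializing the system.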