From 3446f153cbe9230244467fcee044bfe597e14b8e Mon Sep 17 00:00:00 2001
From: Revan Sopher
Date: Wed, 27 May 2020 14:11:01 -0700
Subject: [PATCH] Support Google-internal TPU resolution in strategy
 combinations.

PiperOrigin-RevId: 313461094
Change-Id: I42c3c87a7f8c41d6b11b6de408ddfc44fd8bbda0
---
 .../custom_training_loop_gradient_test.py |  1 -
 .../distribute/strategy_combinations.py   | 29 ++++++++++++-------
 2 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/tensorflow/python/distribute/custom_training_loop_gradient_test.py b/tensorflow/python/distribute/custom_training_loop_gradient_test.py
index ebf5d440c3e..c2ce2caccd0 100644
--- a/tensorflow/python/distribute/custom_training_loop_gradient_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_gradient_test.py
@@ -111,7 +111,6 @@ class GradientTapeTest(test.TestCase, parameterized.TestCase,
         return grads
       return distribution.experimental_local_results(
           distribution.run(train_step, args=(x,)))
-
     dist_dataset = distribution.experimental_distribute_dataset(dataset)
     results = []
     for x in dist_dataset:
diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py
index e69c8c7f129..ca7a0d6d1fc 100644
--- a/tensorflow/python/distribute/strategy_combinations.py
+++ b/tensorflow/python/distribute/strategy_combinations.py
@@ -47,7 +47,6 @@ from tensorflow.python.training import ftrl
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import rmsprop
 
-
 FLAGS = flags.FLAGS
 
 _did_connect_to_cluster = False
@@ -58,16 +57,26 @@ def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
   def _create_tpu_strategy():
     global _did_connect_to_cluster
 
-    # These flags will be defined by tpu_test_wrapper.py.
-    resolver = tpu_cluster_resolver.TPUClusterResolver(
-        tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
-        zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
-        project=hasattr(FLAGS, "project") and FLAGS.project or None,
-    )
+    try:
+      # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
+      # which case we fall back to the values passed as flags.
+      resolver = tpu_cluster_resolver.TPUClusterResolver()
+      did_automatically_resolve = True
+    except ValueError:
+      did_automatically_resolve = False
+
+      # These flags will be defined by tpu_test_wrapper.py.
+      resolver = tpu_cluster_resolver.TPUClusterResolver(
+          tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
+          zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
+          project=hasattr(FLAGS, "project") and FLAGS.project or None,
+      )
+
     # Only connect once per process, rather than per test method.
-    if hasattr(FLAGS, "tpu") and FLAGS.tpu and not _did_connect_to_cluster:
-      remote.connect_to_cluster(resolver)
-      _did_connect_to_cluster = True
+    if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
+      if not _did_connect_to_cluster:
+        remote.connect_to_cluster(resolver)
+        _did_connect_to_cluster = True
 
     topology = tpu_strategy_util.initialize_tpu_system(resolver)
     device_assignment = None
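
Note: taken out of its diff context, the resolution flow this patch introduces
amounts to the standalone sketch below. It assumes a TF 2.x source tree where
these module paths exist; resolve_tpu_and_connect() is a hypothetical name used
only for illustration (in the patch, this logic lives inside
_create_tpu_strategy()).

from absl import flags

from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import remote

FLAGS = flags.FLAGS
_did_connect_to_cluster = False


def resolve_tpu_and_connect():
  """Prefers local (Google-internal) TPU discovery, falls back to flags."""
  global _did_connect_to_cluster
  try:
    # Zero-argument resolution discovers a locally visible TPU; per the
    # patch's comment, it raises ValueError when nothing can be discovered
    # automatically (e.g. Cloud TPU).
    resolver = tpu_cluster_resolver.TPUClusterResolver()
    did_automatically_resolve = True
  except ValueError:
    did_automatically_resolve = False
    # Fall back to the flags defined by tpu_test_wrapper.py.
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
        zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
        project=hasattr(FLAGS, "project") and FLAGS.project or None)
  # Connect only once per process, rather than per test method.
  if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
    if not _did_connect_to_cluster:
      remote.connect_to_cluster(resolver)
      _did_connect_to_cluster = True
  return resolver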
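
Note: the connect guard also swaps 'hasattr(FLAGS, "tpu") and FLAGS.tpu' for
'getattr(FLAGS, "tpu", "")' and splits the once-per-process check into a
nested if, so the automatic-resolution case can share it. The two flag
spellings agree in boolean context, which the toy check below confirms.
_Flags is a hypothetical stand-in for absl's FLAGS, used only for this
demonstration; like FLAGS, a plain object raises AttributeError for an
undefined attribute, which is what both spellings guard against.

class _Flags(object):
  """Hypothetical stand-in for absl's FLAGS object."""


unset = _Flags()        # no "tpu" flag defined at all
empty = _Flags()
empty.tpu = ""          # flag defined but left empty
named = _Flags()
named.tpu = "my-tpu"    # flag defined and set

for f in (unset, empty, named):
  old_style = hasattr(f, "tpu") and f.tpu or ""  # pre-patch spelling
  new_style = getattr(f, "tpu", "")              # post-patch spelling
  assert bool(old_style) == bool(new_style)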