From ca18db7f3f5057bb83c41f4710d7a6a75224300d Mon Sep 17 00:00:00 2001
From: Ruoxin Sang <rxsang@google.com>
Date: Thu, 14 May 2020 01:38:34 -0700
Subject: [PATCH] Return a meaningful error for dynamic shape inputs with outside compilation head extraction in TPUs.

PiperOrigin-RevId: 311490072
Change-Id: Idc7bf1764aba1fcbfcf830e36a5b575b387923d7
---
 .../python/distribute/tpu_strategy_test.py    | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py
index de4c975d5ef..6c93e29c028 100644
--- a/tensorflow/python/distribute/tpu_strategy_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
 from tensorflow.python.eager import remote
 from tensorflow.python.eager import test
+from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -140,6 +141,9 @@ class TPUStrategyTest(test.TestCase):
     # for non-local TPU.
     if FLAGS.tpu:
       self.skipTest("Recovery fails for non-local TPU, see b/148150981")
+
+    # Disable automatic outside compilation.
+    config.set_soft_device_placement(False)
     strategy = get_tpu_strategy()
 
     @def_function.function
@@ -164,6 +168,28 @@ class TPUStrategyTest(test.TestCase):
 
     good_run()
 
+  def test_dynamic_shape_with_outside_compilation_failure(self):
+    # Enable automatic outside compilation.
+    config.set_soft_device_placement(True)
+    strategy = get_tpu_strategy()
+    dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
+        2, drop_remainder=False)
+    dataset = strategy.experimental_distribute_dataset(dataset)
+    iterator = iter(dataset)
+
+    @def_function.function
+    def train_fn(iterator):
+
+      def step_fn(inputs):
+        _, inputs = inputs
+        return math_ops.reduce_sum(inputs)
+
+      return strategy.experimental_local_results(
+          strategy.run(step_fn, args=(next(iterator),)))
+
+    with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
+      logging.info(train_fn(iterator))
+
   def test_computation_on_subset_cores(self):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)