diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index de4c975d5ef..6c93e29c028 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.eager import remote from tensorflow.python.eager import test +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -140,6 +141,9 @@ class TPUStrategyTest(test.TestCase): # for non-local TPU. if FLAGS.tpu: self.skipTest("Recovery fails for non-local TPU, see b/148150981") + + # Disable automatic outside compilation. + config.set_soft_device_placement(False) strategy = get_tpu_strategy() @def_function.function @@ -164,6 +168,28 @@ class TPUStrategyTest(test.TestCase): good_run() + def test_dynamic_shape_with_outside_compilation_failure(self): + # Enable automatic outside compilation. + config.set_soft_device_placement(True) + strategy = get_tpu_strategy() + dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch( + 2, drop_remainder=False) + dataset = strategy.experimental_distribute_dataset(dataset) + iterator = iter(dataset) + + @def_function.function + def train_fn(iterator): + + def step_fn(inputs): + _, inputs = inputs + return math_ops.reduce_sum(inputs) + + return strategy.experimental_local_results( + strategy.run(step_fn, args=(next(iterator),))) + + with self.assertRaisesRegex(errors.InternalError, "Compilation failure"): + logging.info(train_fn(iterator)) + def test_computation_on_subset_cores(self): resolver = get_tpu_cluster_resolver() remote.connect_to_cluster(resolver)