From 85bf5f7c202f1c656ebf169592aa4a0a9c022e8a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 14 May 2020 02:32:13 -0700 Subject: [PATCH] Return a meaningful error for dynamic shape inputs with outside compilation head extraction in TPUs. PiperOrigin-RevId: 311495416 Change-Id: I42b12ac545224c32e770d963a5f3f333ba280531 --- .../python/distribute/tpu_strategy_test.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index 6c93e29c028..de4c975d5ef 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -28,7 +28,6 @@ from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.eager import remote from tensorflow.python.eager import test -from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -141,9 +140,6 @@ class TPUStrategyTest(test.TestCase): # for non-local TPU. if FLAGS.tpu: self.skipTest("Recovery fails for non-local TPU, see b/148150981") - - # Disable automatic outside compilation. - config.set_soft_device_placement(False) strategy = get_tpu_strategy() @def_function.function @@ -168,28 +164,6 @@ class TPUStrategyTest(test.TestCase): good_run() - def test_dynamic_shape_with_outside_compilation_failure(self): - # Enable automatic outside compilation. - config.set_soft_device_placement(True) - strategy = get_tpu_strategy() - dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch( - 2, drop_remainder=False) - dataset = strategy.experimental_distribute_dataset(dataset) - iterator = iter(dataset) - - @def_function.function - def train_fn(iterator): - - def step_fn(inputs): - _, inputs = inputs - return math_ops.reduce_sum(inputs) - - return strategy.experimental_local_results( - strategy.run(step_fn, args=(next(iterator),))) - - with self.assertRaisesRegex(errors.InternalError, "Compilation failure"): - logging.info(train_fn(iterator)) - def test_computation_on_subset_cores(self): resolver = get_tpu_cluster_resolver() remote.connect_to_cluster(resolver)