From ca18db7f3f5057bb83c41f4710d7a6a75224300d Mon Sep 17 00:00:00 2001
From: Ruoxin Sang <rxsang@google.com>
Date: Thu, 14 May 2020 01:38:34 -0700
Subject: [PATCH] Return a meaningful error for dynamic shape inputs with
 outside compilation head extraction in TPUs.

PiperOrigin-RevId: 311490072
Change-Id: Idc7bf1764aba1fcbfcf830e36a5b575b387923d7
---
 .../python/distribute/tpu_strategy_test.py    | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py
index de4c975d5ef..6c93e29c028 100644
--- a/tensorflow/python/distribute/tpu_strategy_test.py
+++ b/tensorflow/python/distribute/tpu_strategy_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
 from tensorflow.python.eager import remote
 from tensorflow.python.eager import test
+from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -140,6 +141,9 @@ class TPUStrategyTest(test.TestCase):
     # for non-local TPU.
     if FLAGS.tpu:
       self.skipTest("Recovery fails for non-local TPU, see b/148150981")
+
+    # Disable automatic outside compilation.
+    config.set_soft_device_placement(False)
     strategy = get_tpu_strategy()
 
     @def_function.function
@@ -164,6 +168,28 @@ class TPUStrategyTest(test.TestCase):
 
     good_run()
 
+  def test_dynamic_shape_with_outside_compilation_failure(self):
+    # Enable automatic outside compilation.
+    config.set_soft_device_placement(True)
+    strategy = get_tpu_strategy()
+    dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
+        2, drop_remainder=False)
+    dataset = strategy.experimental_distribute_dataset(dataset)
+    iterator = iter(dataset)
+
+    @def_function.function
+    def train_fn(iterator):
+
+      def step_fn(inputs):
+        _, inputs = inputs
+        return math_ops.reduce_sum(inputs)
+
+      return strategy.experimental_local_results(
+          strategy.run(step_fn, args=(next(iterator),)))
+
+    with self.assertRaisesRegex(errors.InternalError, "Compilation failure"):
+      logging.info(train_fn(iterator))
+
   def test_computation_on_subset_cores(self):
     resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(resolver)