From a0c8fee0e68900f70f70ff301da398543421a10c Mon Sep 17 00:00:00 2001
From: Saurabh Saxena <srbs@google.com>
Date: Tue, 3 Mar 2020 16:22:40 -0800
Subject: [PATCH] Support Identity/StopGradient ops for TensorLists in tf2xla
 bridge. This is needed to support back_prop=False in v2 control flow.

PiperOrigin-RevId: 298720533
Change-Id: I33cd2f2603cff07193c0275878ff22c4c8338fa8
---
 tensorflow/compiler/tests/while_test.py          | 16 ++++++++++++++++
 .../compiler/tf2xla/kernels/identity_op.cc       | 16 +++++++++++-----
 2 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py
index 420dc04bec3..f1f8b6c353c 100644
--- a/tensorflow/compiler/tests/while_test.py
+++ b/tensorflow/compiler/tests/while_test.py
@@ -240,6 +240,22 @@ class WhileTest(xla_test.XLATestCase):
       self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums]))
       xla_context.Exit()
 
+  @test_util.enable_control_flow_v2
+  def testMapBackPropFalse(self):
+    if is_compile_on_demand():
+      self.skipTest("list_ops are not supported in cpu_ondemand")
+    with self.session(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+      nums = [1, 2, 3, 4, 5, 6]
+      elems = constant_op.constant(nums, name="data")
+      r = map_fn.map_fn(
+          lambda x: math_ops.multiply(math_ops.add(x, 3), 2),
+          elems,
+          back_prop=False)
+      self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums]))
+      xla_context.Exit()
+
 
 def is_compile_on_demand():
   return ("TF_XLA_FLAGS" in os.environ and
diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
index 8b27e8e85a3..38d8056d3e5 100644
--- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 
@@ -25,10 +26,15 @@ class IdentityOp : public XlaOpKernel {
 
   void Compile(XlaOpKernelContext* ctx) override {
     for (int i = 0; i < ctx->num_inputs(); ++i) {
-      // Forwards using the underlying op_kernel_context so both tensor and
-      // resource values are forwarded correctly.
-      ctx->op_kernel_context()->set_output(i,
-                                           ctx->op_kernel_context()->input(i));
+      if (IsTensorListInput(ctx, i)) {
+        ctx->SetTensorListOutput(i, ctx->Input(i));
+      } else {
+        DCHECK(ctx->input_type(i) != DT_VARIANT);
+        // Forwards using the underlying op_kernel_context so both tensor and
+        // resource values are forwarded correctly.
+        ctx->op_kernel_context()->set_output(
+            i, ctx->op_kernel_context()->input(i));
+      }
     }
   }
 
@@ -48,7 +54,7 @@ REGISTER_XLA_OP(Name("IdentityN")
                 IdentityOp);
 REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp);
 REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp);
-REGISTER_XLA_OP(Name("StopGradient"), IdentityOp);
+REGISTER_XLA_OP(Name("StopGradient").AllowVariantTypes(), IdentityOp);
 REGISTER_XLA_OP(Name("Snapshot"), IdentityOp);
 
 }  // namespace