From 1619f2f19f3280f36bd351142314f8a8248ac903 Mon Sep 17 00:00:00 2001
From: Russell Power <power@google.com>
Date: Mon, 31 Aug 2020 09:16:55 -0700
Subject: [PATCH] Implement TensorStridedSliceAssign XLA op.

PiperOrigin-RevId: 329315230
Change-Id: I5aca22493f5fa38fcd03a3f78f6d9e9afdaadb8b
---
 .../compiler/jit/mark_for_compilation_pass.cc |  4 ++-
 .../mlir/tensorflow/ir/tf_generated_ops.td    | 34 +++++++++++++++++++
 .../tf2xla/kernels/strided_slice_op.cc        | 19 +++++++++--
 .../python/kernel_tests/array_ops_test.py     |  4 +++
 4 files changed, 58 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 03ac7b0a59a..af0a192639c 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -1834,7 +1834,9 @@ absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
       "ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse",
       "ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV",
       "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
-      "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}};
+      "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex",
+      "TensorStridedSliceUpdate",
+     }}};
   // clang-format on
   return result;
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index faf7d428aea..9b201ba878e 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -11651,6 +11651,40 @@ On GPU, if an out of bound index is found, the index is ignored.
   ];
 }
 
+def TF_TensorStridedSliceUpdateOp : TF_Op<"TensorStridedSliceUpdate", [NoSideEffect]> {
+  let summary = "Assign `value` to the sliced l-value reference of `input`.";
+
+  let description = [{
+The values of `value` are assigned to the positions in the tensor `input` that
+are selected by the slice parameters. The slice parameters `begin` `end`
+`strides` etc. work exactly as in `StridedSlice`.
+
+NOTE this op currently does not support broadcasting and so `value`'s shape
+must be exactly the shape produced by the slice of `input`.
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$input,
+    TF_I32OrI64Tensor:$begin,
+    TF_I32OrI64Tensor:$end,
+    TF_I32OrI64Tensor:$strides,
+    TF_Tensor:$value,
+
+    DefaultValuedAttr<I64Attr, "0">:$begin_mask,
+    DefaultValuedAttr<I64Attr, "0">:$end_mask,
+    DefaultValuedAttr<I64Attr, "0">:$ellipsis_mask,
+    DefaultValuedAttr<I64Attr, "0">:$new_axis_mask,
+    DefaultValuedAttr<I64Attr, "0">:$shrink_axis_mask
+  );
+
+  let results = (outs
+    TF_Tensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
+}
+
 def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> {
   let summary = "Constructs a tensor by tiling a given tensor.";
 
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 268317d84fc..6e5afd98e9d 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -446,7 +446,12 @@ class StridedSliceAssignOp : public XlaOpKernel {
 
     TensorShape lhs_shape;
     xla::XlaOp lhs;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
+    if (ctx->input_type(0) == DT_RESOURCE) {
+      OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
+    } else {
+      lhs_shape = ctx->InputShape(0);
+      lhs = ctx->Input(0);
+    }
 
     const TensorShape rhs_shape = ctx->InputShape(4);
 
@@ -504,7 +509,11 @@ class StridedSliceAssignOp : public XlaOpKernel {
 
     lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin);
 
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
+    if (ctx->input_type(0) == DT_RESOURCE) {
+      OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
+    } else {
+      ctx->SetOutput(0, lhs);
+    }
   }
 
  private:
@@ -520,5 +529,11 @@ REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
                     .CompileTimeConstantInput("strides"),
                 StridedSliceAssignOp);
 
+REGISTER_XLA_OP(Name("TensorStridedSliceUpdate")
+                    .CompileTimeConstantInput("begin")
+                    .CompileTimeConstantInput("end")
+                    .CompileTimeConstantInput("strides"),
+                StridedSliceAssignOp);
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 391930e20d5..7714b010147 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1234,6 +1234,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       with self.assertRaises(ValueError):
         sess.run(v[:].assign(too_small_val))
 
+  @test_util.disable_xla("b/123559667")
   @test_util.run_in_graph_and_eager_modes
   def testTensorStridedSliceUpdateWithInputForward(self):
     """Tests tensor_strided_slice_update with input-forwarding taking effect."""
@@ -1243,6 +1244,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       return gen_array_ops.tensor_strided_slice_update(y, [0], [1], [1], [0])
     self.assertAllEqual([0, 1], self.evaluate(assign(array_ops.zeros([2]))))
 
+  @test_util.disable_xla("b/123559667")
   @test_util.run_in_graph_and_eager_modes
   def testTensorStridedSliceUpdateNoInputForward(self):
     """Tests tensor_strided_slice_update with no input-forwarding."""
@@ -1254,6 +1256,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     ans = y + z
     self.assertAllClose([1.6, 2.6], self.evaluate(ans))
 
+  @test_util.disable_xla("b/123559667")
   def testTensorStridedSliceUpdateGradSimple(self):
     original = constant_op.constant([0.2, 0.3])
     updates = constant_op.constant([0.4])
@@ -1272,6 +1275,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase, parameterized.TestCase):
           ([4], [5], [3], [1], [3], 1, 0, 0, 0, 0),
           ([2, 2, 3, 2], [0, 0, 1], [1, 0, 2], [1, 0, 1], [2, 3], 0, 0, 2, 0, 5)
       ]))
+  @test_util.disable_xla("b/123559667")
   def testTensorStridedSliceUpdateGrad(
       self, shape, begin, end, strides, updates_shape, *args):
     with self.cached_session():