From 16d2b56de211a28805eca2413474c6ca535ff3a5 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Sat, 29 Aug 2020 14:28:13 -0700
Subject: [PATCH] Add op definition and bridge implementation for
 TensorStridedSliceUpdate.

PiperOrigin-RevId: 329122036
Change-Id: I616b81ec2b328eddc77f662c77ae956972cd265a
---
 .../compiler/jit/mark_for_compilation_pass.cc |  1 -
 .../mlir/tensorflow/ir/tf_generated_ops.td    | 34 -------------------
 .../tf2xla/kernels/strided_slice_op.cc        | 19 ++---------
 3 files changed, 2 insertions(+), 52 deletions(-)

diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 849019a76c4..03ac7b0a59a 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -2048,7 +2048,6 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
                                      "TensorScatterAdd",
                                      "TensorScatterSub",
                                      "TensorScatterUpdate",
-                                     "TensorStridedSliceUpdate",
                                      "TridiagonalSolve",
                                      "TruncatedNormal",
                                      "UpperBound",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 9b201ba878e..faf7d428aea 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -11651,40 +11651,6 @@ On GPU, if an out of bound index is found, the index is ignored.
   ];
 }
 
-def TF_TensorStridedSliceUpdateOp : TF_Op<"TensorStridedSliceUpdate", [NoSideEffect]> {
-  let summary = "Assign `value` to the sliced l-value reference of `input`.";
-
-  let description = [{
-The values of `value` are assigned to the positions in the tensor `input` that
-are selected by the slice parameters. The slice parameters `begin` `end`
-`strides` etc. work exactly as in `StridedSlice`.
-
-NOTE this op currently does not support broadcasting and so `value`'s shape
-must be exactly the shape produced by the slice of `input`.
-  }];
-
-  let arguments = (ins
-    TF_Tensor:$input,
-    TF_I32OrI64Tensor:$begin,
-    TF_I32OrI64Tensor:$end,
-    TF_I32OrI64Tensor:$strides,
-    TF_Tensor:$value,
-
-    DefaultValuedAttr<I64Attr, "0">:$begin_mask,
-    DefaultValuedAttr<I64Attr, "0">:$end_mask,
-    DefaultValuedAttr<I64Attr, "0">:$ellipsis_mask,
-    DefaultValuedAttr<I64Attr, "0">:$new_axis_mask,
-    DefaultValuedAttr<I64Attr, "0">:$shrink_axis_mask
-  );
-
-  let results = (outs
-    TF_Tensor:$output
-  );
-
-  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
-  TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
-}
-
 def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> {
   let summary = "Constructs a tensor by tiling a given tensor.";
 
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 6e5afd98e9d..268317d84fc 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -446,12 +446,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
 
     TensorShape lhs_shape;
     xla::XlaOp lhs;
-    if (ctx->input_type(0) == DT_RESOURCE) {
-      OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
-    } else {
-      lhs_shape = ctx->InputShape(0);
-      lhs = ctx->Input(0);
-    }
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
 
     const TensorShape rhs_shape = ctx->InputShape(4);
 
@@ -509,11 +504,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
 
     lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin);
 
-    if (ctx->input_type(0) == DT_RESOURCE) {
-      OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
-    } else {
-      ctx->SetOutput(0, lhs);
-    }
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
   }
 
  private:
@@ -529,11 +520,5 @@ REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
                     .CompileTimeConstantInput("strides"),
                 StridedSliceAssignOp);
 
-REGISTER_XLA_OP(Name("TensorStridedSliceUpdate")
-                    .CompileTimeConstantInput("begin")
-                    .CompileTimeConstantInput("end")
-                    .CompileTimeConstantInput("strides"),
-                StridedSliceAssignOp);
-
 }  // namespace
 }  // namespace tensorflow