From 41b3e84fa920ba51e7bf321471462d5e6076776f Mon Sep 17 00:00:00 2001
From: Yunxing Dai <yunxing@google.com>
Date: Wed, 10 Feb 2021 15:23:45 -0800
Subject: [PATCH] Support more cases in strided slice ops.

- When ranges are dynamic, we can lower strided slice ops into xla dynamic slices (see the usage sketch below).
- Optimize xla builder dynamism inference to not materialize the parameter tuple.
- Export some helper data structures from strided_slice_op.
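
A hypothetical usage sketch (illustrative only, not taken from this change): a
strided slice whose begin/end are runtime tensors can now be compiled with
XLA, as long as all strides are 1:

  import tensorflow as tf

  @tf.function(experimental_compile=True)  # compile with XLA
  def dynamic_slice(x, begin, end):
    # begin/end are not compile-time constants; the bridge lowers this to an
    # xla dynamic slice.
    return tf.strided_slice(x, begin, end, strides=[1, 1])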

PiperOrigin-RevId: 356839276
Change-Id: I3b5d9fcb1289496029afbf97e49df49df26ce699
---
 .../tf2xla/kernels/strided_slice_op.cc        | 258 ++++++++++++------
 tensorflow/compiler/xla/client/xla_builder.cc |  26 +-
 tensorflow/core/util/strided_slice_op.cc      |  48 ++--
 tensorflow/core/util/strided_slice_op.h       |  36 ++-
 4 files changed, 242 insertions(+), 126 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 8bb22c6e7b0..d5e7577862b 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/ops_util.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -52,6 +53,145 @@ class StridedSliceOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
   }
 
+  void EmitDynamicSlice(XlaOpKernelContext* ctx,
+                        const absl::InlinedVector<int64, 4>& strides,
+                        TensorShape processing_shape, TensorShape final_shape,
+                        PartialTensorShape partial_processing_shape,
+                        PartialTensorShape partial_final_shape,
+                        const StridedSliceShapeSpec& shape_spec,
+                        const std::vector<bool>& begins_are_dynamic,
+                        const std::vector<bool>& ends_are_dynamic) {
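+    // All strides must be 1 here: the non-constant begin/end case is lowered
+    // to an xla::DynamicSlice, and the true sizes of dynamically sliced
+    // dimensions are attached to the result via SetDimensionSize.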
+    const TensorShape input_shape = ctx->InputShape(0);
+    xla::XlaOp slice = ctx->Input(0);
+    for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
+      OP_REQUIRES(ctx, strides[i] == 1,
+                  errors::InvalidArgument(
+                      "Strides have to be one when inputs are not constant."));
+    }
+    // Infer the static output shape, reconciling unknown dimensions with the
+    // input dimension sizes.
+    for (int64 i = 0; i < partial_final_shape.dims(); ++i) {
+      if (partial_final_shape.dim_size(i) == -1) {
+        // Fill the unknown dimension from the input shape via shape_spec.
+        partial_final_shape.set_dim(
+            i,
+            input_shape.dim_size(shape_spec.output_to_processing_mapping[i]));
+      }
+    }
+
+    OP_REQUIRES(
+        ctx, partial_final_shape.AsTensorShape(&final_shape),
+        InvalidArgument("XLA can't deduce compile time constant output "
+                        "shape for strided slice: ",
+                        partial_final_shape.DebugString(),
+                        ", output shape must be a compile-time constant"));
+    for (int64 i = 0; i < partial_processing_shape.dims(); ++i) {
+      if (partial_processing_shape.dim_size(i) == -1) {
+        // Fill the unknown dimension from the corresponding input dimension.
+        partial_processing_shape.set_dim(i, input_shape.dim_size(i));
+      }
+    }
+    OP_REQUIRES(
+        ctx, partial_processing_shape.AsTensorShape(&processing_shape),
+        InvalidArgument("XLA can't deduce compile time constant processing "
+                        "shape for strided slice: ",
+                        partial_processing_shape.DebugString(),
+                        ", processing shape must be a compile-time constant"));
+    // When inputs are not compile-time constants, shape inference can only
+    // infer a size-1 slice.
+    std::vector<int64> slice_sizes(input_shape.dims(), 1);
+    // If a dimension has a dynamic begin/end (and is not shrunk), we need the
+    // dynamic shape infrastructure: we slice that dimension with its full
+    // size and then call SetDimensionSize on the output. However, slicing the
+    // full size from a non-zero start index may read out of bounds. To avoid
+    // that, we first pad the input to 2x before slicing.
+    xla::PaddingConfig padding_config;
+    bool need_padding = false;
+    std::vector<bool> result_dims_are_dynamic;
+    for (int64 i = 0; i < input_shape.dims(); ++i) {
+      int64 sparse_index = shape_spec.processing_to_sparse_mapping[i];
+      bool shrink_axis_set = (1 << i) & shape_spec.shrink_axis_dense_mask;
+      auto* dims = padding_config.add_dimensions();
+      dims->set_edge_padding_low(0);
+
+      dims->set_interior_padding(0);
+      if ((begins_are_dynamic[sparse_index] ||
+           ends_are_dynamic[sparse_index]) &&
+          !shrink_axis_set) {
+        // Need to slice this dimension so pad first.
+        dims->set_edge_padding_high(input_shape.dim_size(i));
+        need_padding = true;
+        result_dims_are_dynamic.push_back(true);
+      } else {
+        dims->set_edge_padding_high(0);
+        result_dims_are_dynamic.push_back(false);
+      }
+    }
+
+    if (need_padding) {
+      // Pad input to 2x to avoid OOB access.
+      slice = xla::Pad(slice, xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
+                       padding_config);
+    }
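+    // Compute a dynamic start index and a dynamic slice size for each input
+    // dimension, wrapping negative indices and honoring the begin/end masks.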
+    std::vector<xla::XlaOp> start_indices;
+    std::vector<xla::XlaOp> slice_sizes_dynamic;
+    xla::Shape input_xla_shape = ctx->InputXlaShape(0).ValueOrDie();
+    for (int64 i = 0; i < input_shape.dims(); ++i) {
+      bool begin_mask = (1 << i) & shape_spec.begin_dense_mask;
+      bool end_mask = (1 << i) & shape_spec.end_dense_mask;
+      auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin"));
+      xla::XlaOp begin_index, end_index;
+      int64 sparse_index = shape_spec.processing_to_sparse_mapping[i];
+      bool xla_input_is_dynamic = input_xla_shape.is_dynamic_dimension(i);
+      xla::XlaOp dim_size;
+      if (xla_input_is_dynamic) {
+        dim_size = xla::GetDimensionSize(ctx->Input(0), i);
+        OP_REQUIRES(ctx, ctx->InputXlaType("begin") == xla::S32,
+                    errors::InvalidArgument("'begin' shape has to be int32 when "
+                                            "indices to slice op are dynamic"));
+      } else {
+        dim_size =
+            xla::ConstantR0WithType(ctx->builder(), ctx->InputXlaType("begin"),
+                                    input_xla_shape.dimensions(i));
+      }
+      if (begin_mask) {
+        begin_index = zero;
+      } else {
+        begin_index = xla::Slice(ctx->Input("begin"), {sparse_index},
+                                 {sparse_index + 1}, {1});
+        begin_index = xla::Reshape(begin_index, {});
+        auto index_negative = xla::Lt(begin_index, zero);
+        auto wrapped_index = xla::Add(dim_size, begin_index);
+        // Wrap negative indices around.
+        begin_index = xla::Select(index_negative, wrapped_index, begin_index);
+      }
+      start_indices.push_back(begin_index);
+      if (end_mask) {
+        end_index = dim_size;
+      } else {
+        end_index = xla::Slice(ctx->Input("end"), {sparse_index},
+                               {sparse_index + 1}, {1});
+        end_index = xla::Reshape(end_index, {});
+        auto index_negative = xla::Lt(end_index, zero);
+        auto wrapped_index = xla::Add(dim_size, end_index);
+        end_index = xla::Select(index_negative, wrapped_index, end_index);
+      }
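+      // The dynamic size of this dimension is end - begin, clamped at zero.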
+      slice_sizes_dynamic.push_back(
+          xla::Max(xla::Sub(end_index, begin_index), zero));
+    }
+
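+    // Slice with the static processing-shape sizes; the true dynamic sizes
+    // are attached below via SetDimensionSize.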
+    slice =
+        xla::DynamicSlice(slice, start_indices, processing_shape.dim_sizes());
+
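+    // Mark dynamically sliced dimensions with their runtime sizes.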
+    for (int64 i = 0; i < input_shape.dims(); ++i) {
+      if (result_dims_are_dynamic[i]) {
+        slice = xla::SetDimensionSize(slice, slice_sizes_dynamic[i], i);
+      }
+    }
+    slice = xla::Reshape(slice, final_shape.dim_sizes());
+    ctx->SetOutput(0, slice);
+  }
+
   void Compile(XlaOpKernelContext* ctx) override {
     const TensorShape input_shape = ctx->InputShape(0);
     const TensorShape begin_shape = ctx->InputShape("begin");
@@ -80,31 +220,33 @@ class StridedSliceOp : public XlaOpKernel {
     }
     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                             &strides_tensor));
-
-    TensorShape final_shape;
-    PartialTensorShape dummy_processing_shape, partial_final_shape;
+    TensorShape processing_shape, final_shape;
+    PartialTensorShape partial_processing_shape, partial_final_shape;
     bool dummy = false;
-    absl::InlinedVector<int64, 4> output_to_sparse_mapping;
-    absl::InlinedVector<int64, 4> output_to_processing_mapping;
+    StridedSliceShapeSpec shape_spec;
     OP_REQUIRES_OK(
         ctx,
         ValidateStridedSliceOp(
             begin_is_constant ? &begin_tensor : nullptr,
             end_is_constant ? &end_tensor : nullptr, strides_tensor,
             input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
-            shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape,
-            &dummy, &dummy, &dummy, &begin, &end, &strides,
-            &output_to_sparse_mapping, &output_to_processing_mapping));
-
-    OP_REQUIRES(
-        ctx, partial_final_shape.AsTensorShape(&final_shape),
-        InvalidArgument("XLA can't deduce compile time constant output "
-                        "shape for strided slice: ",
-                        partial_final_shape.DebugString(),
-                        ", output shape must be a compile-time constant"));
+            shrink_axis_mask_, &partial_processing_shape, &partial_final_shape,
+            &dummy, &dummy, &dummy, &begin, &end, &strides, &shape_spec));
 
     xla::XlaOp slice = ctx->Input(0);
+    std::vector<bool> begins_are_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
+    std::vector<bool> ends_are_dynamic;
+    OP_REQUIRES_OK(
+        ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
     if (begin_is_constant && end_is_constant) {
+      OP_REQUIRES(
+          ctx, partial_final_shape.AsTensorShape(&final_shape),
+          InvalidArgument("XLA can't deduce compile time constant output "
+                          "shape for strided slice: ",
+                          partial_final_shape.DebugString(),
+                          ", output shape must be a compile-time constant"));
       absl::InlinedVector<int64, 4> dimensions_to_reverse;
       absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
       for (int i = 0; i < begin.size(); ++i) {
@@ -129,12 +271,7 @@ class StridedSliceOp : public XlaOpKernel {
       auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0));
       OP_REQUIRES_OK(ctx, operand_shape_or.status());
       xla::Shape xla_shape = operand_shape_or.ValueOrDie();
-      std::vector<bool> begins_are_dynamic;
-      OP_REQUIRES_OK(
-          ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic));
-      std::vector<bool> ends_are_dynamic;
-      OP_REQUIRES_OK(
-          ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic));
+
       bool begins_are_static = absl::c_all_of(
           begins_are_dynamic, [](bool dynamic) { return !dynamic; });
       OP_REQUIRES(ctx, begins_are_static,
@@ -150,13 +287,13 @@ class StridedSliceOp : public XlaOpKernel {
       }
 
       for (int64 i = 0; i < final_shape.dims(); ++i) {
-        int64 input_index = output_to_processing_mapping[i];
+        int64 input_index = shape_spec.output_to_processing_mapping[i];
         if (input_index == -1) {
           continue;
         }
         bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index);
 
-        int64 sparse_index = output_to_sparse_mapping[i];
+        int64 sparse_index = shape_spec.output_to_sparse_mapping[i];
         bool end_is_dynamic =
             sparse_index == -1 ? false : ends_are_dynamic[sparse_index];
         bool backward_slice = sparse_index == -1
@@ -208,62 +345,9 @@ class StridedSliceOp : public XlaOpKernel {
       ctx->SetOutput(0, slice);
       return;
     } else {
-      // When output shape is fully defined, it must be a size one slice:
-      //
-      // 1. The number of output elements has to be equal to the number of input
-      // elements that are sliced.
-      // 2. The stride of the slice dimensions must be exact one.
-      int64 output_elements = final_shape.num_elements();
-
-      int64 input_elements_sliced = 1;
-      int64 slicing_dim_size = begin_shape.dim_size(0);
-      // We only support slicing major dimensions, so minor dimensions after
-      // slicing dimension are all sliced with their full sizes.
-      for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
-        input_elements_sliced *= input_shape.dim_size(d);
-      }
-
-      OP_REQUIRES(ctx, output_elements == input_elements_sliced,
-                  errors::InvalidArgument(
-                      "Dynamic indices of strided_slice_op have to be leading "
-                      "dimensions in the indices list."));
-
-      for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) {
-        OP_REQUIRES(
-            ctx, strides[i] == 1,
-            errors::InvalidArgument(
-                "Strides have to be one when inputs are not constant."));
-      }
-
-      // When inputs are not compile time constants, shape inference can only
-      // inference size 1 slice.
-      std::vector<int64> slice_sizes(slicing_dim_size, 1);
-      std::vector<xla::XlaOp> start_indices;
-      auto zero = xla::Zero(ctx->builder(), ctx->InputXlaType("begin"));
-      for (int64 d = 0; d < slicing_dim_size; ++d) {
-        auto index = xla::Slice(ctx->Input("begin"), {d}, {d + 1}, {1});
-        // Convert index to scalar.
-        index = xla::Reshape(index, {});
-        // Negative index: wrap it around with dimension size.
-        auto index_negative = xla::Lt(index, zero);
-        auto dim_size = xla::ConvertElementType(
-            xla::ConstantR0<int32>(ctx->builder(), input_shape.dim_size(d)),
-            ctx->InputXlaType("begin"));
-        auto wrapped_index = xla::Add(dim_size, index);
-        index = xla::Select(index_negative, wrapped_index, index);
-        start_indices.push_back(index);
-      }
-
-      for (int64 d = slicing_dim_size; d < input_shape.dims(); ++d) {
-        // For non-slice dims, naturally we get the full slice starting from 0.
-        slice_sizes.push_back(input_shape.dim_size(d));
-        start_indices.push_back(zero);
-      }
-
-      std::vector<int64> output_shape_dim_sizes;
-      slice = xla::DynamicSlice(slice, start_indices, slice_sizes);
-      slice = xla::Reshape(slice, final_shape.dim_sizes());
-      ctx->SetOutput(0, slice);
+      EmitDynamicSlice(ctx, strides, processing_shape, final_shape,
+                       partial_processing_shape, partial_final_shape,
+                       shape_spec, begins_are_dynamic, ends_are_dynamic);
     }
   }
 
@@ -308,10 +392,7 @@ class StridedSliceGradOp : public XlaOpKernel {
     absl::InlinedVector<int64, 4> begin;
     absl::InlinedVector<int64, 4> end;
     absl::InlinedVector<int64, 4> strides;
-
-    absl::InlinedVector<int64, 4> output_to_sparse_mapping;
-    absl::InlinedVector<int64, 4> output_to_processing_mapping;
-
+    StridedSliceShapeSpec shape_spec;
     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                             &strides_tensor));
     OP_REQUIRES_OK(
@@ -319,8 +400,7 @@ class StridedSliceGradOp : public XlaOpKernel {
                  nullptr, nullptr, strides_tensor, input_shape, begin_mask_,
                  end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
                  &processing_shape, &final_shape, &dummy, &dummy, &dummy,
-                 &begin, &end, &strides, &output_to_sparse_mapping,
-                 &output_to_processing_mapping));
+                 &begin, &end, &strides, &shape_spec));
     for (int64 i = 0; i < processing_shape.dims(); ++i) {
       OP_REQUIRES(
           ctx, strides[i] == 1,
@@ -341,20 +421,20 @@ class StridedSliceGradOp : public XlaOpKernel {
       // Use grad shape, which is known, to update unknown processing shape.
       // Grad shape is the output of the ValidateStridedSliceOp function in
       // forward pass, thus we use output_to_processing_mapping.
-      if (output_to_processing_mapping[i] != -1) {
-        processing_shape.set_dim(output_to_processing_mapping[i],
+      if (shape_spec.output_to_processing_mapping[i] != -1) {
+        processing_shape.set_dim(shape_spec.output_to_processing_mapping[i],
                                  grad_shape.dimensions(i));
       }
 
       // Similarly, use output_to_sparse_mapping to find out corresponding
       // begin dim of the output, as indices for dynamic update slice.
-      int64 begin_dim = output_to_sparse_mapping[i];
+      int64 begin_dim = shape_spec.output_to_sparse_mapping[i];
       if (begin_dim != -1) {
         auto begin_index =
             xla::Slice(ctx->Input(1), {begin_dim}, {begin_dim + 1}, {1});
         auto begin_index_scalar = xla::Reshape(
             xla::ShapeUtil::MakeScalarShape(xla::S32), begin_index);
-        begins[output_to_sparse_mapping[i]] = begin_index_scalar;
+        begins[shape_spec.output_to_sparse_mapping[i]] = begin_index_scalar;
       }
     }
     VLOG(1) << "processing_shape" << processing_shape.DebugString();
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index ef31735563d..35cd1c25b7d 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -3546,13 +3546,25 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
       }
       case HloOpcode::kTuple:
       case HloOpcode::kTranspose:
-      case HloOpcode::kGetTupleElement:
       case HloOpcode::kSlice:
       case HloOpcode::kBroadcast:
       case HloOpcode::kConcatenate:
       case HloOpcode::kReshape:
       case HloOpcode::kPad:
         break;
+      case HloOpcode::kGetTupleElement: {
+        // Rewrite a get-tuple-element of a parameter into a constant to avoid
+        // materializing the whole tuple parameter (which could be very large).
+        int64 operand_handle = instr_proto->operand_ids(0);
+        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
+                            LookUpInstructionByHandle(operand_handle));
+        TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
+                            StringToHloOpcode(operand_proto->opcode()));
+        if (operand_opcode == HloOpcode::kParameter) {
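+          // A parameter is considered dynamic, so the element read from it is
+          // marked dynamic directly.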
+          SetInstructionAsConstant(new_instr, id, new_shape, true);
+        }
+        break;
+      }
       case HloOpcode::kGetDimensionSize: {
         int64 dimension = instr_proto->dimensions(0);
         int64 operand_handle = instr_proto->operand_ids(0);
@@ -3646,6 +3658,18 @@ StatusOr<XlaComputation> XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) {
       should_visit_operand = false;
     }
 
+    if (opcode == HloOpcode::kGetTupleElement) {
+      int64 operand_handle = instr_proto->operand_ids(0);
+      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
+                          LookUpInstructionByHandle(operand_handle));
+      TF_ASSIGN_OR_RETURN(HloOpcode operand_opcode,
+                          StringToHloOpcode(operand_proto->opcode()));
+      if (operand_opcode == HloOpcode::kParameter) {
+        // Don't materialize the whole tuple parameter when it is accessed
+        // through a GTE.
+        should_visit_operand = false;
+      }
+    }
+
     if (opcode == HloOpcode::kSelect) {
       TF_ASSIGN_OR_RETURN(bool constant_predicate,
                           operand_is_constant(instr_proto, 0));
diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc
index 1cf9a8cd013..126b684b8c7 100644
--- a/tensorflow/core/util/strided_slice_op.cc
+++ b/tensorflow/core/util/strided_slice_op.cc
@@ -16,6 +16,8 @@ limitations under the License.
 #include "tensorflow/core/util/strided_slice_op.h"
 
 #include <array>
+#include <iterator>
+
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/lib/core/status.h"
 
@@ -64,6 +66,7 @@ struct StridedSliceDenseSpec {
   // index. A -1 in this vector means there the index is not from the sparse
   // input.
   gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
+  gtl::InlinedVector<int32, 4> input_shape_gather_indices_sparse;
   // The dense indexed shrink mask is which processing dimensions
   // should be shrunk. For example, if foo.shape = (10,10,10,10)
   // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
@@ -81,6 +84,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
   dense->begin.resize(dense->dims);
   dense->end.resize(dense->dims);
   dense->strides.resize(dense->dims);
+  dense->input_shape_gather_indices_sparse.resize(dense->dims);
   // What indices to get the final shape from.
   dense->begin_mask = 0;
   dense->end_mask = 0;
@@ -114,6 +118,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
           dense->end_mask |= (1 << full_index);
           dense->final_shape_gather_indices.push_back(full_index);
           dense->final_shape_gather_indices_sparse.push_back(-1);
+          dense->input_shape_gather_indices_sparse[full_index] = i;
         }
       } else if ((1 << i) & sparse.new_axis_mask) {
         dense->final_shape_gather_indices.push_back(kNewAxis);
@@ -153,6 +158,7 @@ static Status TF_MUST_USE_RESULT BuildDenseSpec(
           // from.
           dense->final_shape_gather_indices_sparse.push_back(i);
         }
+        dense->input_shape_gather_indices_sparse[full_index] = i;
         full_index++;
       }
     }
@@ -168,9 +174,7 @@ Status ValidateStridedSliceOp(
     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
-    gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
+    gtl::InlinedVector<int64, 4>* strides, StridedSliceShapeSpec* shape_spec) {
   const bool begin_is_wrong =
       begin_tensor != nullptr &&
       !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
@@ -375,13 +379,18 @@ Status ValidateStridedSliceOp(
   // slices like foo[3,...] will reduce dimension by 1.
   // This cannot be done earlier, because it depends on Step 3.
   final_shape->Clear();
-  if (output_to_sparse_mapping != nullptr) {
-    output_to_sparse_mapping->clear();
+  if (shape_spec != nullptr) {
+    shape_spec->output_to_sparse_mapping.clear();
+    shape_spec->output_to_processing_mapping.clear();
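+    // Each dense (processing) dimension maps back to the sparse index it came
+    // from; expose that as processing_to_sparse_mapping.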
+    shape_spec->processing_to_sparse_mapping.assign(
+        dense_spec.input_shape_gather_indices_sparse.begin(),
+        dense_spec.input_shape_gather_indices_sparse.end());
+
+    shape_spec->begin_dense_mask = dense_spec.begin_mask;
+    shape_spec->end_dense_mask = dense_spec.end_mask;
+    shape_spec->shrink_axis_dense_mask = dense_spec.shrink_axis_mask;
   }
 
-  if (output_to_processing_mapping != nullptr) {
-    output_to_processing_mapping->clear();
-  }
   for (int64 dense_dim = 0;
        dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
     int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
@@ -389,22 +398,19 @@ Status ValidateStridedSliceOp(
         dense_spec.final_shape_gather_indices_sparse[dense_dim];
     if (gather_index >= 0) {
       final_shape->AddDim(processing_shape->dim_size(gather_index));
-      if (output_to_sparse_mapping != nullptr) {
-        output_to_sparse_mapping->push_back(sparse_index);
-      }
-      if (output_to_processing_mapping != nullptr) {
-        output_to_processing_mapping->push_back(gather_index);
+      if (shape_spec != nullptr) {
+        shape_spec->output_to_sparse_mapping.push_back(sparse_index);
+        shape_spec->output_to_processing_mapping.push_back(gather_index);
       }
     } else if (gather_index == kNewAxis) {
       final_shape->AddDim(1);
-      if (output_to_sparse_mapping != nullptr) {
-        output_to_sparse_mapping->push_back(-1);
-      }
-      if (output_to_processing_mapping != nullptr) {
-        output_to_processing_mapping->push_back(-1);
+      if (shape_spec != nullptr) {
+        shape_spec->output_to_sparse_mapping.push_back(-1);
+        shape_spec->output_to_processing_mapping.push_back(-1);
       }
     }
   }
+
   return Status::OK();
 }
 
@@ -416,16 +422,14 @@ Status ValidateStridedSliceOp(
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
     gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping) {
+    StridedSliceShapeSpec* shape_spec) {
   // Validate with PartialTensorShape output
   PartialTensorShape partial_processing_shape, partial_final_shape;
   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
       begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
       end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
       &partial_processing_shape, &partial_final_shape, is_identity,
-      is_simple_slice, slice_dim0, begin, end, strides,
-      output_to_sparse_mapping, output_to_processing_mapping));
+      is_simple_slice, slice_dim0, begin, end, strides, shape_spec));
 
   // Verify that the output shapes are fully known
   if (!partial_processing_shape.AsTensorShape(processing_shape) ||
diff --git a/tensorflow/core/util/strided_slice_op.h b/tensorflow/core/util/strided_slice_op.h
index 9e49477a9c3..dfb411a6065 100644
--- a/tensorflow/core/util/strided_slice_op.h
+++ b/tensorflow/core/util/strided_slice_op.h
@@ -23,6 +23,26 @@ limitations under the License.
 
 namespace tensorflow {
 
+struct StridedSliceShapeSpec {
+  // Begin mask canonicalized in dense form.
+  int32 begin_dense_mask;
+  // End mask canonicalized in dense form.
+  int32 end_dense_mask;
+  // Shrink axis mask canonicalized in dense form.
+  int32 shrink_axis_dense_mask;
+  // output_to_sparse_mapping[i] is the dim index in the begin_tensor that
+  // output[i] corresponds to. If output_to_sparse_mapping[i] is -1, the
+  // dimension doesn't show up in the sparse indices.
+  gtl::InlinedVector<int64, 4> output_to_sparse_mapping;
+  // output_to_processing_mapping is similar to output_to_sparse_mapping, but
+  // for processing shape.
+  gtl::InlinedVector<int64, 4> output_to_processing_mapping;
+  // processing_to_sparse_mapping[i] represents input_shape[i]'s corresponding
+  // dim index in the begin_tensor.
+  gtl::InlinedVector<int64, 4> processing_to_sparse_mapping;
+};
+
 // Runs validation on the strided slice op parameters.
 //
 // Is a separate translation unit from the kernel so that:
@@ -41,16 +61,6 @@ namespace tensorflow {
 // (-1). Any validation that can be done without complete information is
 // performed.
 //
-// This function changes the orders of dimensions, output_to_sparse_mapping and
-// output_to_processing_mapping are used to track the order change.
-//
-// output_to_sparse_mapping[i] represents output[i]'s the corresponding dim
-// index in the begin_tensor. If
-// output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up in
-// sparse_mapping.
-//
-// output_to_processing_mapping is similar to output_to_sparse_mapping, but for
-// processing_shape.
 Status ValidateStridedSliceOp(
     const Tensor* begin_tensor, const Tensor* end_tensor,
     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
@@ -60,8 +70,7 @@ Status ValidateStridedSliceOp(
     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
     gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
+    StridedSliceShapeSpec* shape_spec = nullptr);
 
 // Same as above, but the outputs are TensorShape, not PartialTensorShape
 Status ValidateStridedSliceOp(
@@ -72,8 +81,7 @@ Status ValidateStridedSliceOp(
     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
     gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
-    gtl::InlinedVector<int64, 4>* output_to_sparse_mapping = nullptr,
-    gtl::InlinedVector<int64, 4>* output_to_processing_mapping = nullptr);
+    StridedSliceShapeSpec* shape_spec = nullptr);
 
 }  // namespace tensorflow