From e159982568d290f15768182085742ade9d3ef6bf Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung <windqaq@gmail.com>
Date: Wed, 3 Feb 2021 14:50:43 -0800
Subject: [PATCH] Fold StridedSliceOp when input is defined by ShapeOp.

The pattern is common is TF python library like height = tf.shape(x)[1].
When x has some dynamic dimensions (typically batch dim), tf.shape can not be constant folded so height cannot be inferred as a constant.
This PR folds this kind of patterns to improve sub-shape constant folding.

Rename some testcases

Correctly handle negative strides

Add testcases for out of bound begin and end

clang-format

Address comments

Fix Windows build. Templated Lambda is not supported in MSVC.

Fix shrink_axis_mask with negative begin

Use canonicalization pattern instead of folder to better support unranked and dynamic output

Switch back to folder
---
 .../mlir/tensorflow/ir/tf_generated_ops.td    |   2 +
 .../compiler/mlir/tensorflow/ir/tf_ops_n_z.cc | 118 +++++++++
 .../mlir/tensorflow/tests/canonicalize.mlir   | 234 ++++++++++++++++--
 3 files changed, 338 insertions(+), 16 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 7cc9d5168b7..c472b249f9d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -14412,6 +14412,8 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
   TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
 
+  let hasFolder = 1;
+
   let verifier = [{ return VerifyStridedSliceBase(*this); }];
 
   let extraClassDeclaration = [{
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 0f8a423124f..eba06465b50 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -1886,6 +1886,124 @@ bool StridedSliceOp::GetSlicedBoundRanges(
   return true;
 }
 
+OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
+  // Fold StridedSlice operation if it extracts statically known dimensions.
+  //
+  // For example,
+  //
+  //   %shape  = tf.Shape(%arg)                   // %arg: tensor<?x2x3x1xf32>
+  //   %height = tf.StridedSlice(%shape, 1, 2, 1)
+  //
+  // In this case %height can be replaced with a constant 2.
+  //
+  // Or,
+  //
+  //   %shape  = tf.Shape(%arg)                   // %arg: tensor<?x2x3x1xf32>
+  //   %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1)
+  //
+  // In this case %spatial_shape can be replaced with a constant [2, 3].
+
+  // Input to strided slice op is defined by shape operation.
+  auto shape_op = input().getDefiningOp<ShapeOp>();
+  if (!shape_op) {
+    return {};
+  }
+
+  // `begin`, `end` and `strides` should be constant in order to infer static
+  // dimension.
+  DenseIntElementsAttr begin_attr, end_attr, strides_attr;
+  if (!matchPattern(begin(), m_Constant(&begin_attr)) ||
+      !matchPattern(end(), m_Constant(&end_attr)) ||
+      !matchPattern(strides(), m_Constant(&strides_attr)) ||
+      begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 ||
+      strides_attr.getNumElements() != 1) {
+    return {};
+  }
+
+  // Do not fold when `new_axis_mask` is set. It's likely to break the shape
+  // of output. Typically, `new_axis_mask` is not set in this canonicalization
+  // pattern.
+  if (new_axis_mask() != 0) return {};
+
+  auto tensor_ty = shape_op.input().getType().dyn_cast<RankedTensorType>();
+  // Only ranked tensor can be folded.
+  if (!tensor_ty) return {};
+
+  int64_t rank = tensor_ty.getRank();
+  int64_t begin_int = begin_attr.getValue<APInt>(0).getSExtValue();
+  int64_t end_int = end_attr.getValue<APInt>(0).getSExtValue();
+  int64_t strides_int = strides_attr.getValue<APInt>(0).getSExtValue();
+
+  // Canonicalize `begin` and `end` in case of negative index.
+  if (begin_int < 0) begin_int += rank;
+  if (end_int < 0) end_int += rank;
+
+  // Create `begin` and `end` from `*_mask`. Note that we don't care about
+  // `new_axis_mask` as it can be inferred from `output_ty`.
+  if (shrink_axis_mask() == 1) {
+    // When `shrink_axis_mask` is set, output is always a scalar so only
+    // one element is sliced.
+    end_int = begin_int + 1;
+  }
+  if (begin_mask() == 1) {
+    begin_int = (strides_int > 0) ? 0 : rank - 1;
+  }
+  if (end_mask() == 1) {
+    end_int = (strides_int > 0) ? rank : -1;
+  }
+  if (ellipsis_mask() == 1) {
+    begin_int = 0;
+    end_int = rank;
+  }
+
+  // It's possible that `begin` and `end` are out of bound. See
+  // https://docs.python.org/3/library/stdtypes.html#common-sequence-operations.
+  if (strides_int > 0) {
+    begin_int = std::min(begin_int, rank);
+    end_int = std::min(end_int, rank);
+  } else {
+    begin_int = std::min(begin_int, rank - 1);
+    end_int = std::min(end_int, rank - 1);
+  }
+
+  SmallVector<int64_t, 2> sub_shape;
+  // Only handle cases that have something to slice to avoid infinite for-loop.
+  if ((end_int > begin_int && strides_int > 0) ||
+      (end_int < begin_int && strides_int < 0)) {
+    // Extract sub-shape only if all of those dimensions are static.
+    for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int;
+         i += strides_int) {
+      if (tensor_ty.isDynamicDim(i)) {
+        return {};
+      }
+      sub_shape.push_back(tensor_ty.getDimSize(i));
+    }
+  }
+
+  // For unranked or dynamic output, we infer the output type to either a
+  // scalar or a vector based on `shrink_axis_mask` because we have rejected
+  // the case of `new_axis_mask` != 0.
+  auto output_elt_ty = output().getType().cast<ShapedType>().getElementType();
+  auto output_ty = output().getType().dyn_cast<RankedTensorType>();
+  if (!output_ty || !output_ty.hasStaticShape()) {
+    if (shrink_axis_mask() == 1) {
+      output_ty = RankedTensorType::get({}, output_elt_ty);
+    } else {
+      output_ty = RankedTensorType::get(
+          {static_cast<int64_t>(sub_shape.size())}, output_elt_ty);
+    }
+  }
+
+  // Down-cast to 32 bit int if needed.
+  if (output_elt_ty.isInteger(32)) {
+    SmallVector<int32_t, 2> sub_shape_i32(sub_shape.size());
+    std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(),
+                   [](int64_t d) { return static_cast<int32_t>(d); });
+    return DenseIntElementsAttr::get(output_ty, sub_shape_i32);
+  }
+  return DenseIntElementsAttr::get(output_ty, sub_shape);
+}
+
 //===----------------------------------------------------------------------===//
 // StridedSliceGradOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index e2a0552ef48..841e6ddb1cf 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -486,7 +486,7 @@ func @testBroadcastToNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tenso
 }
 
 // CHECK-LABEL: func @testPackShapeComputation
-func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>,  tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
+func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
   // Test dimensions sizes.
   %d1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   %d2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
@@ -526,26 +526,20 @@ func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>,
   %15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
   // CHECK: %[[PACK0:.*]] = "tf.Pack"
 
-  // StridedSlice takes second dimension from the shape:
-  //   begin = [1], end = [2], stride = [1]
-  %17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
-  %18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
-  // CHECK: %[[PACK1:.*]] = "tf.Pack"
-
   // Packed dimensions have higher rank than the reshape operand:
   //   [?, 1] vs [?, 1, 1]
-  %20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
-  %21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
-  // CHECK: %[[PACK2:.*]] = "tf.Pack"
+  %16 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+  %17 = "tf.Pack"(%16, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
+  // CHECK: %[[PACK1:.*]] = "tf.Pack"
 
   // Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass
-  %23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
-  %24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
-  %25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
-  // CHECK: %[[PACK3:.*]] = "tf.Pack"
+  %18 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
+  %19 = "tf.StridedSlice"(%18, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
+  %20 = "tf.Pack"(%19, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: %[[PACK2:.*]] = "tf.Pack"
 
-  // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]]
-  return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>,  tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
+  // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]]
+  return %5, %9, %15, %17, %20 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
 }
 
 // CHECK-LABEL: testSelectScalarPred
@@ -1373,3 +1367,211 @@ func @testUnpackAndCwiseUnary(%arg0: tensor<?x2xf32>) -> (tensor<?xf32>, tensor<
   // CHECK: return %[[UNPACK]]#0, %[[UNPACK]]#1
   return %0, %1 : tensor<?xf32>, tensor<?xf32>
 }
+
+// CHECK-LABEL: testFoldStridedSliceShapeI32
+func @testFoldStridedSliceShapeI32(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi32>) {
+  %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
+  %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  return %3 : tensor<2xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeI64
+func @testFoldStridedSliceShapeI64(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi64>) {
+  %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi64>
+  %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64>
+  return %3 : tensor<2xi64>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeDynamicOutput
+func @testFoldStridedSliceShapeDynamicOutput(%arg0: tensor<?x1x2x?xf32>) -> (tensor<?xi32>) {
+  %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
+  %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
+  return %3 : tensor<?xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<?xi32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI32
+func @testFoldStridedSliceShapeWithShrinkAxisMaskI32(%arg0: tensor<?x1x2x?xf32>) -> (tensor<i32>) {
+  %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
+  %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+  return %3 : tensor<i32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI64
+func @testFoldStridedSliceShapeWithShrinkAxisMaskI64(%arg0: tensor<?x1x2x?xf32>) -> (tensor<i64>) {
+  %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi64>
+  %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
+  return %3 : tensor<i64>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskUnrankedOutput
+func @testFoldStridedSliceShapeWithShrinkAxisMaskUnrankedOutput(%arg0: tensor<?x1x2x?xf32>) -> (tensor<*xi32>) {
+  %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
+  %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
+  return %3 : tensor<*xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1
+func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1(%arg0: tensor<?x1x2x3xf32>) -> (tensor<i32>) {
+  %0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+  %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+  return %4 : tensor<i32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2
+func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2(%arg0: tensor<?x1x2x3xf32>) -> (tensor<i32>) {
+  %0 = "tf.Const"() {value = dense<-2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+  %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+  return %4 : tensor<i32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testUnfoldedStridedSliceShape
+func @testUnfoldedStridedSliceShape(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi32>) {
+  %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
+  %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  return %4 : tensor<2xi32>
+  // CHECK: %[[SLICE:.*]] = "tf.StridedSlice"
+  // CHECK: return %[[SLICE]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithBeginMask
+func @testFoldStridedSliceShapeWithBeginMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<2xi32>) {
+  %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
+  %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  return %4 : tensor<2xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithEndMask
+func @testFoldStridedSliceShapeWithEndMask(%arg0: tensor<?x1x2x3xf32>) -> (tensor<3xi32>) {
+  %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+  %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+  return %3 : tensor<3xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStrides
+func @testFoldStridedSliceShapeWithPositiveStrides(%arg0: tensor<1x2x3x4x?xf32>) -> (tensor<2xi32>) {
+  %0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x4x?xf32>) -> tensor<5xi32>
+  %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<5xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  return %4 : tensor<2xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd
+func @testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd(%arg0: tensor<?x1x2x3xf32>) -> (tensor<3xi32>) {
+  %0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+  %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+  return %3 : tensor<3xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStrides
+func @testFoldStridedSliceShapeWithNegativeStrides(%arg0: tensor<1x2x3x?xf32>) -> (tensor<1xi32>) {
+  %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
+  %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  return %4 : tensor<1xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin
+func @testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin(%arg0: tensor<?x1x2x3xf32>) -> (tensor<2xi32>) {
+  %0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+  %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  return %4 : tensor<2xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesBeginMask
+func @testFoldStridedSliceShapeWithNegativeStridesBeginMask(%arg0: tensor<?x1x2x3xf32>) -> (tensor<2xi32>) {
+  %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+  %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
+  return %4 : tensor<2xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesEndMask
+func @testFoldStridedSliceShapeWithNegativeStridesEndMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<3xi32>) {
+  %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
+  %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
+  return %4 : tensor<3xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // CHECK: return %[[CST]]
+}
+
+// CHECK-LABEL: testFoldStridedSliceShapeWithEmptySlice
+func @testFoldStridedSliceShapeWithEmptySlice(%arg0: tensor<?x1x2x3xf32>) -> (tensor<0xi32>) {
+  %0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
+  %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
+  %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32>
+  return %4 : tensor<0xi32>
+  // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+  // CHECK: return %[[CST]]
+}