diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 7cc9d5168b7..c472b249f9d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -14412,6 +14412,8 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>; + let hasFolder = 1; + let verifier = [{ return VerifyStridedSliceBase(*this); }]; let extraClassDeclaration = [{ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 0f8a423124f..eba06465b50 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -1886,6 +1886,124 @@ bool StridedSliceOp::GetSlicedBoundRanges( return true; } +OpFoldResult StridedSliceOp::fold(ArrayRef operands) { + // Fold StridedSlice operation if it extracts statically known dimensions. + // + // For example, + // + // %shape = tf.Shape(%arg) // %arg: tensor + // %height = tf.StridedSlice(%shape, 1, 2, 1) + // + // In this case %height can be replaced with a constant 2. + // + // Or, + // + // %shape = tf.Shape(%arg) // %arg: tensor + // %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1) + // + // In this case %spatial_shape can be replaced with a constant [2, 3]. + + // Input to strided slice op is defined by shape operation. + auto shape_op = input().getDefiningOp(); + if (!shape_op) { + return {}; + } + + // `begin`, `end` and `strides` should be constant in order to infer static + // dimension. + DenseIntElementsAttr begin_attr, end_attr, strides_attr; + if (!matchPattern(begin(), m_Constant(&begin_attr)) || + !matchPattern(end(), m_Constant(&end_attr)) || + !matchPattern(strides(), m_Constant(&strides_attr)) || + begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 || + strides_attr.getNumElements() != 1) { + return {}; + } + + // Do not fold when `new_axis_mask` is set. It's likely to break the shape + // of output. Typically, `new_axis_mask` is not set in this canonicalization + // pattern. + if (new_axis_mask() != 0) return {}; + + auto tensor_ty = shape_op.input().getType().dyn_cast(); + // Only ranked tensor can be folded. + if (!tensor_ty) return {}; + + int64_t rank = tensor_ty.getRank(); + int64_t begin_int = begin_attr.getValue(0).getSExtValue(); + int64_t end_int = end_attr.getValue(0).getSExtValue(); + int64_t strides_int = strides_attr.getValue(0).getSExtValue(); + + // Canonicalize `begin` and `end` in case of negative index. + if (begin_int < 0) begin_int += rank; + if (end_int < 0) end_int += rank; + + // Create `begin` and `end` from `*_mask`. Note that we don't care about + // `new_axis_mask` as it can be inferred from `output_ty`. + if (shrink_axis_mask() == 1) { + // When `shrink_axis_mask` is set, output is always a scalar so only + // one element is sliced. + end_int = begin_int + 1; + } + if (begin_mask() == 1) { + begin_int = (strides_int > 0) ? 0 : rank - 1; + } + if (end_mask() == 1) { + end_int = (strides_int > 0) ? rank : -1; + } + if (ellipsis_mask() == 1) { + begin_int = 0; + end_int = rank; + } + + // It's possible that `begin` and `end` are out of bound. See + // https://docs.python.org/3/library/stdtypes.html#common-sequence-operations. + if (strides_int > 0) { + begin_int = std::min(begin_int, rank); + end_int = std::min(end_int, rank); + } else { + begin_int = std::min(begin_int, rank - 1); + end_int = std::min(end_int, rank - 1); + } + + SmallVector sub_shape; + // Only handle cases that have something to slice to avoid infinite for-loop. + if ((end_int > begin_int && strides_int > 0) || + (end_int < begin_int && strides_int < 0)) { + // Extract sub-shape only if all of those dimensions are static. + for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int; + i += strides_int) { + if (tensor_ty.isDynamicDim(i)) { + return {}; + } + sub_shape.push_back(tensor_ty.getDimSize(i)); + } + } + + // For unranked or dynamic output, we infer the output type to either a + // scalar or a vector based on `shrink_axis_mask` because we have rejected + // the case of `new_axis_mask` != 0. + auto output_elt_ty = output().getType().cast().getElementType(); + auto output_ty = output().getType().dyn_cast(); + if (!output_ty || !output_ty.hasStaticShape()) { + if (shrink_axis_mask() == 1) { + output_ty = RankedTensorType::get({}, output_elt_ty); + } else { + output_ty = RankedTensorType::get( + {static_cast(sub_shape.size())}, output_elt_ty); + } + } + + // Down-cast to 32 bit int if needed. + if (output_elt_ty.isInteger(32)) { + SmallVector sub_shape_i32(sub_shape.size()); + std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(), + [](int64_t d) { return static_cast(d); }); + return DenseIntElementsAttr::get(output_ty, sub_shape_i32); + } + return DenseIntElementsAttr::get(output_ty, sub_shape); +} + //===----------------------------------------------------------------------===// // StridedSliceGradOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index e2a0552ef48..841e6ddb1cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -486,7 +486,7 @@ func @testBroadcastToNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tenso } // CHECK-LABEL: func @testPackShapeComputation -func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) { +func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) { // Test dimensions sizes. %d1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %d2 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor @@ -526,26 +526,20 @@ func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> // CHECK: %[[PACK0:.*]] = "tf.Pack" - // StridedSlice takes second dimension from the shape: - // begin = [1], end = [2], stride = [1] - %17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> - // CHECK: %[[PACK1:.*]] = "tf.Pack" - // Packed dimensions have higher rank than the reshape operand: // [?, 1] vs [?, 1, 1] - %20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> - // CHECK: %[[PACK2:.*]] = "tf.Pack" + %16 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %17 = "tf.Pack"(%16, %d1, %d1) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> + // CHECK: %[[PACK1:.*]] = "tf.Pack" // Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass - %23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32> - %24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32> - %25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor) -> tensor<*xi32> - // CHECK: %[[PACK3:.*]] = "tf.Pack" + %18 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32> + %19 = "tf.StridedSlice"(%18, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32> + %20 = "tf.Pack"(%19, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor) -> tensor<*xi32> + // CHECK: %[[PACK2:.*]] = "tf.Pack" - // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]] - return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32> + // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]] + return %5, %9, %15, %17, %20 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32> } // CHECK-LABEL: testSelectScalarPred @@ -1373,3 +1367,211 @@ func @testUnpackAndCwiseUnary(%arg0: tensor) -> (tensor, tensor< // CHECK: return %[[UNPACK]]#0, %[[UNPACK]]#1 return %0, %1 : tensor, tensor } + +// CHECK-LABEL: testFoldStridedSliceShapeI32 +func @testFoldStridedSliceShapeI32(%arg0: tensor) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %3 : tensor<2xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeI64 +func @testFoldStridedSliceShapeI64(%arg0: tensor) -> (tensor<2xi64>) { + %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi64> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64> + return %3 : tensor<2xi64> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeDynamicOutput +func @testFoldStridedSliceShapeDynamicOutput(%arg0: tensor) -> (tensor) { + %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + return %3 : tensor + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI32 +func @testFoldStridedSliceShapeWithShrinkAxisMaskI32(%arg0: tensor) -> (tensor) { + %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + return %3 : tensor + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI64 +func @testFoldStridedSliceShapeWithShrinkAxisMaskI64(%arg0: tensor) -> (tensor) { + %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi64> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + return %3 : tensor + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskUnrankedOutput +func @testFoldStridedSliceShapeWithShrinkAxisMaskUnrankedOutput(%arg0: tensor) -> (tensor<*xi32>) { + %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32> + return %3 : tensor<*xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor<*xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1 +func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1(%arg0: tensor) -> (tensor) { + %0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + return %4 : tensor + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2 +func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2(%arg0: tensor) -> (tensor) { + %0 = "tf.Const"() {value = dense<-2> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + return %4 : tensor + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testUnfoldedStridedSliceShape +func @testUnfoldedStridedSliceShape(%arg0: tensor) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %4 : tensor<2xi32> + // CHECK: %[[SLICE:.*]] = "tf.StridedSlice" + // CHECK: return %[[SLICE]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithBeginMask +func @testFoldStridedSliceShapeWithBeginMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %4 : tensor<2xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithEndMask +func @testFoldStridedSliceShapeWithEndMask(%arg0: tensor) -> (tensor<3xi32>) { + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + return %3 : tensor<3xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStrides +func @testFoldStridedSliceShapeWithPositiveStrides(%arg0: tensor<1x2x3x4x?xf32>) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x4x?xf32>) -> tensor<5xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<5xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %4 : tensor<2xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd +func @testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd(%arg0: tensor) -> (tensor<3xi32>) { + %0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + return %3 : tensor<3xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStrides +func @testFoldStridedSliceShapeWithNegativeStrides(%arg0: tensor<1x2x3x?xf32>) -> (tensor<1xi32>) { + %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + return %4 : tensor<1xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin +func @testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin(%arg0: tensor) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %4 : tensor<2xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesBeginMask +func @testFoldStridedSliceShapeWithNegativeStridesBeginMask(%arg0: tensor) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %4 : tensor<2xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesEndMask +func @testFoldStridedSliceShapeWithNegativeStridesEndMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<3xi32>) { + %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + return %4 : tensor<3xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithEmptySlice +func @testFoldStridedSliceShapeWithEmptySlice(%arg0: tensor) -> (tensor<0xi32>) { + %0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + return %4 : tensor<0xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK: return %[[CST]] +}