Fold StridedSliceOp when input is defined by ShapeOp.

The pattern is common is TF python library like height = tf.shape(x)[1].
When x has some dynamic dimensions (typically batch dim), tf.shape can not be constant folded so height cannot be inferred as a constant.
This PR folds this kind of patterns to improve sub-shape constant folding.

Rename some testcases

Correctly handle negative strides

Add testcases for out of bound begin and end

clang-format

Address comments

Fix Windows build. Templated Lambda is not supported in MSVC.

Fix shrink_axis_mask with negative begin

Use canonicalization pattern instead of folder to better support unranked and dynamic output

Switch back to folder
This commit is contained in:
Tzu-Wei Sung 2021-02-03 14:50:43 -08:00
parent 18eaf4e8f1
commit e159982568
3 changed files with 338 additions and 16 deletions

View File

@ -14412,6 +14412,8 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
let hasFolder = 1;
let verifier = [{ return VerifyStridedSliceBase(*this); }];
let extraClassDeclaration = [{

View File

@ -1886,6 +1886,124 @@ bool StridedSliceOp::GetSlicedBoundRanges(
return true;
}
OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
// Fold StridedSlice operation if it extracts statically known dimensions.
//
// For example,
//
// %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
// %height = tf.StridedSlice(%shape, 1, 2, 1)
//
// In this case %height can be replaced with a constant 2.
//
// Or,
//
// %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
// %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1)
//
// In this case %spatial_shape can be replaced with a constant [2, 3].
// Input to strided slice op is defined by shape operation.
auto shape_op = input().getDefiningOp<ShapeOp>();
if (!shape_op) {
return {};
}
// `begin`, `end` and `strides` should be constant in order to infer static
// dimension.
DenseIntElementsAttr begin_attr, end_attr, strides_attr;
if (!matchPattern(begin(), m_Constant(&begin_attr)) ||
!matchPattern(end(), m_Constant(&end_attr)) ||
!matchPattern(strides(), m_Constant(&strides_attr)) ||
begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 ||
strides_attr.getNumElements() != 1) {
return {};
}
// Do not fold when `new_axis_mask` is set. It's likely to break the shape
// of output. Typically, `new_axis_mask` is not set in this canonicalization
// pattern.
if (new_axis_mask() != 0) return {};
auto tensor_ty = shape_op.input().getType().dyn_cast<RankedTensorType>();
// Only ranked tensor can be folded.
if (!tensor_ty) return {};
int64_t rank = tensor_ty.getRank();
int64_t begin_int = begin_attr.getValue<APInt>(0).getSExtValue();
int64_t end_int = end_attr.getValue<APInt>(0).getSExtValue();
int64_t strides_int = strides_attr.getValue<APInt>(0).getSExtValue();
// Canonicalize `begin` and `end` in case of negative index.
if (begin_int < 0) begin_int += rank;
if (end_int < 0) end_int += rank;
// Create `begin` and `end` from `*_mask`. Note that we don't care about
// `new_axis_mask` as it can be inferred from `output_ty`.
if (shrink_axis_mask() == 1) {
// When `shrink_axis_mask` is set, output is always a scalar so only
// one element is sliced.
end_int = begin_int + 1;
}
if (begin_mask() == 1) {
begin_int = (strides_int > 0) ? 0 : rank - 1;
}
if (end_mask() == 1) {
end_int = (strides_int > 0) ? rank : -1;
}
if (ellipsis_mask() == 1) {
begin_int = 0;
end_int = rank;
}
// It's possible that `begin` and `end` are out of bound. See
// https://docs.python.org/3/library/stdtypes.html#common-sequence-operations.
if (strides_int > 0) {
begin_int = std::min(begin_int, rank);
end_int = std::min(end_int, rank);
} else {
begin_int = std::min(begin_int, rank - 1);
end_int = std::min(end_int, rank - 1);
}
SmallVector<int64_t, 2> sub_shape;
// Only handle cases that have something to slice to avoid infinite for-loop.
if ((end_int > begin_int && strides_int > 0) ||
(end_int < begin_int && strides_int < 0)) {
// Extract sub-shape only if all of those dimensions are static.
for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int;
i += strides_int) {
if (tensor_ty.isDynamicDim(i)) {
return {};
}
sub_shape.push_back(tensor_ty.getDimSize(i));
}
}
// For unranked or dynamic output, we infer the output type to either a
// scalar or a vector based on `shrink_axis_mask` because we have rejected
// the case of `new_axis_mask` != 0.
auto output_elt_ty = output().getType().cast<ShapedType>().getElementType();
auto output_ty = output().getType().dyn_cast<RankedTensorType>();
if (!output_ty || !output_ty.hasStaticShape()) {
if (shrink_axis_mask() == 1) {
output_ty = RankedTensorType::get({}, output_elt_ty);
} else {
output_ty = RankedTensorType::get(
{static_cast<int64_t>(sub_shape.size())}, output_elt_ty);
}
}
// Down-cast to 32 bit int if needed.
if (output_elt_ty.isInteger(32)) {
SmallVector<int32_t, 2> sub_shape_i32(sub_shape.size());
std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(),
[](int64_t d) { return static_cast<int32_t>(d); });
return DenseIntElementsAttr::get(output_ty, sub_shape_i32);
}
return DenseIntElementsAttr::get(output_ty, sub_shape);
}
//===----------------------------------------------------------------------===//
// StridedSliceGradOp
//===----------------------------------------------------------------------===//

View File

@ -486,7 +486,7 @@ func @testBroadcastToNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tenso
}
// CHECK-LABEL: func @testPackShapeComputation
func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
// Test dimensions sizes.
%d1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%d2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
@ -526,26 +526,20 @@ func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>,
%15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[PACK0:.*]] = "tf.Pack"
// StridedSlice takes second dimension from the shape:
// begin = [1], end = [2], stride = [1]
%17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[PACK1:.*]] = "tf.Pack"
// Packed dimensions have higher rank than the reshape operand:
// [?, 1] vs [?, 1, 1]
%20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[PACK2:.*]] = "tf.Pack"
%16 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%17 = "tf.Pack"(%16, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[PACK1:.*]] = "tf.Pack"
// Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass
%23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
%24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
%25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK: %[[PACK3:.*]] = "tf.Pack"
%18 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
%19 = "tf.StridedSlice"(%18, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
%20 = "tf.Pack"(%19, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK: %[[PACK2:.*]] = "tf.Pack"
// CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]]
return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
// CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]]
return %5, %9, %15, %17, %20 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
}
// CHECK-LABEL: testSelectScalarPred
@ -1373,3 +1367,211 @@ func @testUnpackAndCwiseUnary(%arg0: tensor<?x2xf32>) -> (tensor<?xf32>, tensor<
// CHECK: return %[[UNPACK]]#0, %[[UNPACK]]#1
return %0, %1 : tensor<?xf32>, tensor<?xf32>
}
// CHECK-LABEL: testFoldStridedSliceShapeI32
func @testFoldStridedSliceShapeI32(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi32>) {
%0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
return %3 : tensor<2xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeI64
func @testFoldStridedSliceShapeI64(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi64>) {
%0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi64>
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64>
return %3 : tensor<2xi64>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeDynamicOutput
func @testFoldStridedSliceShapeDynamicOutput(%arg0: tensor<?x1x2x?xf32>) -> (tensor<?xi32>) {
%0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
return %3 : tensor<?xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI32
func @testFoldStridedSliceShapeWithShrinkAxisMaskI32(%arg0: tensor<?x1x2x?xf32>) -> (tensor<i32>) {
%0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
return %3 : tensor<i32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI64
func @testFoldStridedSliceShapeWithShrinkAxisMaskI64(%arg0: tensor<?x1x2x?xf32>) -> (tensor<i64>) {
%0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi64>
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
return %3 : tensor<i64>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskUnrankedOutput
func @testFoldStridedSliceShapeWithShrinkAxisMaskUnrankedOutput(%arg0: tensor<?x1x2x?xf32>) -> (tensor<*xi32>) {
%0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
return %3 : tensor<*xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1
func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1(%arg0: tensor<?x1x2x3xf32>) -> (tensor<i32>) {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
return %4 : tensor<i32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2
func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2(%arg0: tensor<?x1x2x3xf32>) -> (tensor<i32>) {
%0 = "tf.Const"() {value = dense<-2> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
return %4 : tensor<i32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testUnfoldedStridedSliceShape
func @testUnfoldedStridedSliceShape(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi32>) {
%0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
return %4 : tensor<2xi32>
// CHECK: %[[SLICE:.*]] = "tf.StridedSlice"
// CHECK: return %[[SLICE]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithBeginMask
func @testFoldStridedSliceShapeWithBeginMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<2xi32>) {
%0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
return %4 : tensor<2xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithEndMask
func @testFoldStridedSliceShapeWithEndMask(%arg0: tensor<?x1x2x3xf32>) -> (tensor<3xi32>) {
%0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
return %3 : tensor<3xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStrides
func @testFoldStridedSliceShapeWithPositiveStrides(%arg0: tensor<1x2x3x4x?xf32>) -> (tensor<2xi32>) {
%0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%3 = "tf.Shape"(%arg0) : (tensor<1x2x3x4x?xf32>) -> tensor<5xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<5xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
return %4 : tensor<2xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd
func @testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd(%arg0: tensor<?x1x2x3xf32>) -> (tensor<3xi32>) {
%0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
return %3 : tensor<3xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStrides
func @testFoldStridedSliceShapeWithNegativeStrides(%arg0: tensor<1x2x3x?xf32>) -> (tensor<1xi32>) {
%0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
return %4 : tensor<1xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin
func @testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin(%arg0: tensor<?x1x2x3xf32>) -> (tensor<2xi32>) {
%0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
return %4 : tensor<2xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesBeginMask
func @testFoldStridedSliceShapeWithNegativeStridesBeginMask(%arg0: tensor<?x1x2x3xf32>) -> (tensor<2xi32>) {
%0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
return %4 : tensor<2xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesEndMask
func @testFoldStridedSliceShapeWithNegativeStridesEndMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<3xi32>) {
%0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
return %4 : tensor<3xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: testFoldStridedSliceShapeWithEmptySlice
func @testFoldStridedSliceShapeWithEmptySlice(%arg0: tensor<?x1x2x3xf32>) -> (tensor<0xi32>) {
%0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32>
return %4 : tensor<0xi32>
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
// CHECK: return %[[CST]]
}