Fold StridedSliceOp when input is defined by ShapeOp.
The pattern is common is TF python library like height = tf.shape(x)[1]. When x has some dynamic dimensions (typically batch dim), tf.shape can not be constant folded so height cannot be inferred as a constant. This PR folds this kind of patterns to improve sub-shape constant folding. Rename some testcases Correctly handle negative strides Add testcases for out of bound begin and end clang-format Address comments Fix Windows build. Templated Lambda is not supported in MSVC. Fix shrink_axis_mask with negative begin Use canonicalization pattern instead of folder to better support unranked and dynamic output Switch back to folder
This commit is contained in:
parent
18eaf4e8f1
commit
e159982568
@ -14412,6 +14412,8 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let verifier = [{ return VerifyStridedSliceBase(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -1886,6 +1886,124 @@ bool StridedSliceOp::GetSlicedBoundRanges(
|
||||
return true;
|
||||
}
|
||||
|
||||
OpFoldResult StridedSliceOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Fold StridedSlice operation if it extracts statically known dimensions.
|
||||
//
|
||||
// For example,
|
||||
//
|
||||
// %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
|
||||
// %height = tf.StridedSlice(%shape, 1, 2, 1)
|
||||
//
|
||||
// In this case %height can be replaced with a constant 2.
|
||||
//
|
||||
// Or,
|
||||
//
|
||||
// %shape = tf.Shape(%arg) // %arg: tensor<?x2x3x1xf32>
|
||||
// %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1)
|
||||
//
|
||||
// In this case %spatial_shape can be replaced with a constant [2, 3].
|
||||
|
||||
// Input to strided slice op is defined by shape operation.
|
||||
auto shape_op = input().getDefiningOp<ShapeOp>();
|
||||
if (!shape_op) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// `begin`, `end` and `strides` should be constant in order to infer static
|
||||
// dimension.
|
||||
DenseIntElementsAttr begin_attr, end_attr, strides_attr;
|
||||
if (!matchPattern(begin(), m_Constant(&begin_attr)) ||
|
||||
!matchPattern(end(), m_Constant(&end_attr)) ||
|
||||
!matchPattern(strides(), m_Constant(&strides_attr)) ||
|
||||
begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 ||
|
||||
strides_attr.getNumElements() != 1) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Do not fold when `new_axis_mask` is set. It's likely to break the shape
|
||||
// of output. Typically, `new_axis_mask` is not set in this canonicalization
|
||||
// pattern.
|
||||
if (new_axis_mask() != 0) return {};
|
||||
|
||||
auto tensor_ty = shape_op.input().getType().dyn_cast<RankedTensorType>();
|
||||
// Only ranked tensor can be folded.
|
||||
if (!tensor_ty) return {};
|
||||
|
||||
int64_t rank = tensor_ty.getRank();
|
||||
int64_t begin_int = begin_attr.getValue<APInt>(0).getSExtValue();
|
||||
int64_t end_int = end_attr.getValue<APInt>(0).getSExtValue();
|
||||
int64_t strides_int = strides_attr.getValue<APInt>(0).getSExtValue();
|
||||
|
||||
// Canonicalize `begin` and `end` in case of negative index.
|
||||
if (begin_int < 0) begin_int += rank;
|
||||
if (end_int < 0) end_int += rank;
|
||||
|
||||
// Create `begin` and `end` from `*_mask`. Note that we don't care about
|
||||
// `new_axis_mask` as it can be inferred from `output_ty`.
|
||||
if (shrink_axis_mask() == 1) {
|
||||
// When `shrink_axis_mask` is set, output is always a scalar so only
|
||||
// one element is sliced.
|
||||
end_int = begin_int + 1;
|
||||
}
|
||||
if (begin_mask() == 1) {
|
||||
begin_int = (strides_int > 0) ? 0 : rank - 1;
|
||||
}
|
||||
if (end_mask() == 1) {
|
||||
end_int = (strides_int > 0) ? rank : -1;
|
||||
}
|
||||
if (ellipsis_mask() == 1) {
|
||||
begin_int = 0;
|
||||
end_int = rank;
|
||||
}
|
||||
|
||||
// It's possible that `begin` and `end` are out of bound. See
|
||||
// https://docs.python.org/3/library/stdtypes.html#common-sequence-operations.
|
||||
if (strides_int > 0) {
|
||||
begin_int = std::min(begin_int, rank);
|
||||
end_int = std::min(end_int, rank);
|
||||
} else {
|
||||
begin_int = std::min(begin_int, rank - 1);
|
||||
end_int = std::min(end_int, rank - 1);
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 2> sub_shape;
|
||||
// Only handle cases that have something to slice to avoid infinite for-loop.
|
||||
if ((end_int > begin_int && strides_int > 0) ||
|
||||
(end_int < begin_int && strides_int < 0)) {
|
||||
// Extract sub-shape only if all of those dimensions are static.
|
||||
for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int;
|
||||
i += strides_int) {
|
||||
if (tensor_ty.isDynamicDim(i)) {
|
||||
return {};
|
||||
}
|
||||
sub_shape.push_back(tensor_ty.getDimSize(i));
|
||||
}
|
||||
}
|
||||
|
||||
// For unranked or dynamic output, we infer the output type to either a
|
||||
// scalar or a vector based on `shrink_axis_mask` because we have rejected
|
||||
// the case of `new_axis_mask` != 0.
|
||||
auto output_elt_ty = output().getType().cast<ShapedType>().getElementType();
|
||||
auto output_ty = output().getType().dyn_cast<RankedTensorType>();
|
||||
if (!output_ty || !output_ty.hasStaticShape()) {
|
||||
if (shrink_axis_mask() == 1) {
|
||||
output_ty = RankedTensorType::get({}, output_elt_ty);
|
||||
} else {
|
||||
output_ty = RankedTensorType::get(
|
||||
{static_cast<int64_t>(sub_shape.size())}, output_elt_ty);
|
||||
}
|
||||
}
|
||||
|
||||
// Down-cast to 32 bit int if needed.
|
||||
if (output_elt_ty.isInteger(32)) {
|
||||
SmallVector<int32_t, 2> sub_shape_i32(sub_shape.size());
|
||||
std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(),
|
||||
[](int64_t d) { return static_cast<int32_t>(d); });
|
||||
return DenseIntElementsAttr::get(output_ty, sub_shape_i32);
|
||||
}
|
||||
return DenseIntElementsAttr::get(output_ty, sub_shape);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StridedSliceGradOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -486,7 +486,7 @@ func @testBroadcastToNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tenso
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testPackShapeComputation
|
||||
func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
|
||||
func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
|
||||
// Test dimensions sizes.
|
||||
%d1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%d2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
|
||||
@ -526,26 +526,20 @@ func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>,
|
||||
%15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
|
||||
// CHECK: %[[PACK0:.*]] = "tf.Pack"
|
||||
|
||||
// StridedSlice takes second dimension from the shape:
|
||||
// begin = [1], end = [2], stride = [1]
|
||||
%17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
%18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
|
||||
// CHECK: %[[PACK1:.*]] = "tf.Pack"
|
||||
|
||||
// Packed dimensions have higher rank than the reshape operand:
|
||||
// [?, 1] vs [?, 1, 1]
|
||||
%20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
%21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
|
||||
// CHECK: %[[PACK2:.*]] = "tf.Pack"
|
||||
%16 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
%17 = "tf.Pack"(%16, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
|
||||
// CHECK: %[[PACK1:.*]] = "tf.Pack"
|
||||
|
||||
// Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass
|
||||
%23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
|
||||
%24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
|
||||
%25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
// CHECK: %[[PACK3:.*]] = "tf.Pack"
|
||||
%18 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
|
||||
%19 = "tf.StridedSlice"(%18, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
|
||||
%20 = "tf.Pack"(%19, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
// CHECK: %[[PACK2:.*]] = "tf.Pack"
|
||||
|
||||
// CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]]
|
||||
return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
|
||||
// CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]]
|
||||
return %5, %9, %15, %17, %20 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSelectScalarPred
|
||||
@ -1373,3 +1367,211 @@ func @testUnpackAndCwiseUnary(%arg0: tensor<?x2xf32>) -> (tensor<?xf32>, tensor<
|
||||
// CHECK: return %[[UNPACK]]#0, %[[UNPACK]]#1
|
||||
return %0, %1 : tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeI32
|
||||
func @testFoldStridedSliceShapeI32(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
|
||||
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
|
||||
return %3 : tensor<2xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeI64
|
||||
func @testFoldStridedSliceShapeI64(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi64>) {
|
||||
%0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi64>
|
||||
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64>
|
||||
return %3 : tensor<2xi64>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeDynamicOutput
|
||||
func @testFoldStridedSliceShapeDynamicOutput(%arg0: tensor<?x1x2x?xf32>) -> (tensor<?xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
|
||||
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
return %3 : tensor<?xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI32
|
||||
func @testFoldStridedSliceShapeWithShrinkAxisMaskI32(%arg0: tensor<?x1x2x?xf32>) -> (tensor<i32>) {
|
||||
%0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
|
||||
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
return %3 : tensor<i32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI64
|
||||
func @testFoldStridedSliceShapeWithShrinkAxisMaskI64(%arg0: tensor<?x1x2x?xf32>) -> (tensor<i64>) {
|
||||
%0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi64>
|
||||
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i64>
|
||||
return %3 : tensor<i64>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskUnrankedOutput
|
||||
func @testFoldStridedSliceShapeWithShrinkAxisMaskUnrankedOutput(%arg0: tensor<?x1x2x?xf32>) -> (tensor<*xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
|
||||
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
|
||||
return %3 : tensor<*xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1
|
||||
func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1(%arg0: tensor<?x1x2x3xf32>) -> (tensor<i32>) {
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
|
||||
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
return %4 : tensor<i32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2
|
||||
func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2(%arg0: tensor<?x1x2x3xf32>) -> (tensor<i32>) {
|
||||
%0 = "tf.Const"() {value = dense<-2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
|
||||
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
return %4 : tensor<i32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testUnfoldedStridedSliceShape
|
||||
func @testUnfoldedStridedSliceShape(%arg0: tensor<?x1x2x?xf32>) -> (tensor<2xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x?xf32>) -> tensor<4xi32>
|
||||
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
|
||||
return %4 : tensor<2xi32>
|
||||
// CHECK: %[[SLICE:.*]] = "tf.StridedSlice"
|
||||
// CHECK: return %[[SLICE]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithBeginMask
|
||||
func @testFoldStridedSliceShapeWithBeginMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<2xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
|
||||
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
|
||||
return %4 : tensor<2xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithEndMask
|
||||
func @testFoldStridedSliceShapeWithEndMask(%arg0: tensor<?x1x2x3xf32>) -> (tensor<3xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
|
||||
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
|
||||
return %3 : tensor<3xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStrides
|
||||
func @testFoldStridedSliceShapeWithPositiveStrides(%arg0: tensor<1x2x3x4x?xf32>) -> (tensor<2xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%3 = "tf.Shape"(%arg0) : (tensor<1x2x3x4x?xf32>) -> tensor<5xi32>
|
||||
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<5xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
|
||||
return %4 : tensor<2xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd
|
||||
func @testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd(%arg0: tensor<?x1x2x3xf32>) -> (tensor<3xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
|
||||
%3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
|
||||
return %3 : tensor<3xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStrides
|
||||
func @testFoldStridedSliceShapeWithNegativeStrides(%arg0: tensor<1x2x3x?xf32>) -> (tensor<1xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
|
||||
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
return %4 : tensor<1xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin
|
||||
func @testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin(%arg0: tensor<?x1x2x3xf32>) -> (tensor<2xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
|
||||
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
|
||||
return %4 : tensor<2xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesBeginMask
|
||||
func @testFoldStridedSliceShapeWithNegativeStridesBeginMask(%arg0: tensor<?x1x2x3xf32>) -> (tensor<2xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
|
||||
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
|
||||
return %4 : tensor<2xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesEndMask
|
||||
func @testFoldStridedSliceShapeWithNegativeStridesEndMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<3xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32>
|
||||
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32>
|
||||
return %4 : tensor<3xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFoldStridedSliceShapeWithEmptySlice
|
||||
func @testFoldStridedSliceShapeWithEmptySlice(%arg0: tensor<?x1x2x3xf32>) -> (tensor<0xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%3 = "tf.Shape"(%arg0) : (tensor<?x1x2x3xf32>) -> tensor<4xi32>
|
||||
%4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32>
|
||||
return %4 : tensor<0xi32>
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user