[TF:MLIR] Fold PackOp if it computes tensor shape

Move shape related PackOp folding outside of ReshapeOp folding.

PiperOrigin-RevId: 322935716
Change-Id: I051fb8d7ed0c4507586d869b9c813f8b60634917
commit 37deabbb75
parent fd1481780b

tensorflow/compiler/mlir/tensorflow
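Setting `let hasFolder = 1;` in the first hunk makes ODS declare a fold hook (`OpFoldResult fold(ArrayRef<Attribute> operands)`) on the generated PackOp class, which the C++ hunk then implements. In IR terms, the new fold targets the pattern below (a minimal sketch assembled from the updated tests in this commit; `%c0`, `%c1`, and `%d1` stand for the constant operands the tests call `%0`, `%1`, and `%d1`):

  // Before canonicalization: %5 re-derives the shape of %arg0, taking the
  // dynamic dim0 from tf.Shape and packing it with the static dims.
  %3 = "tf.Shape"(%arg0) : (tensor<?x1xf32>) -> tensor<2xi32>
  %4 = "tf.StridedSlice"(%3, %c0, %c1, %c1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
  %5 = "tf.Pack"(%4, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>

  // After canonicalization: uses of %5 are replaced with %3, leaving the
  // slice/pack pair dead.
  %3 = "tf.Shape"(%arg0) : (tensor<?x1xf32>) -> tensor<2xi32>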
@@ -6334,6 +6334,8 @@ This is the opposite of `unpack`.
   let verifier = [{
     return Verify(*this);
   }];
+
+  let hasFolder = 1;
 }

 def TF_PadOp : TF_Op<"Pad", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
@@ -217,6 +217,97 @@ static LogicalResult Verify(PackOp op) {
   return success();
 }

+OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
+  // Fold pack operation if it computes the input tensor shape:
+  //
+  //   %shape = tf.Shape(%arg)                    // [? x ...]
+  //   %dim0  = tf.StridedSlice(%shape, 0, 1, 1)  // get unknown dim0 value
+  //   %pack  = tf.Pack(dim0, ...) { axis = 0 }   // [? x ...]
+  //
+  // Where `...` are some statically known dimensions. In this case %pack can
+  // be replaced with %shape. This is a common pattern in models with a
+  // dynamic batch size.
+
+  // Pack operation should pack at least two values.
+  if (values().size() < 2) return {};
+
+  // Dimensions packed along axis = 0 (pack scalars into vector).
+  if (axis().getSExtValue() != 0) return {};
+
+  // First packed value is defined by a strided slice operation.
+  auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
+  if (!slice_op) return {};
+
+  // Input to the slice op is defined by a shape operation.
+  auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
+  if (!shape_op) return {};
+
+  // Input tensor whose shape is reconstructed by the pack operation.
+  Value tensor = shape_op.input();
+
+  // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
+  // scalar value from input vector).
+  if (slice_op.begin_mask().getSExtValue() != 0 ||
+      slice_op.ellipsis_mask().getSExtValue() != 0 ||
+      slice_op.end_mask().getSExtValue() != 0 ||
+      slice_op.new_axis_mask().getSExtValue() != 0 ||
+      slice_op.shrink_axis_mask().getSExtValue() != 1)
+    return {};
+
+  // Returns a value if `value` is defined by a ConstOp with a single integer
+  // element and has the expected rank.
+  auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
+    auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
+    if (!const_op) return None;
+
+    auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
+    if (!value_attr || value_attr.getNumElements() != 1) return None;
+
+    auto value_ty = value_attr.getType();
+    if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
+
+    auto splat = value_attr.getSplatValue<IntegerAttr>();
+    return splat.getValue().getSExtValue();
+  };
+
+  // All other packed values are scalar constants.
+  SmallVector<int64_t, 4> packed_dims;
+  packed_dims.reserve(values().size() - 1);
+  for (Value operand : llvm::drop_begin(values(), 1)) {
+    if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
+      packed_dims.push_back(*dim);
+    } else {
+      return {};
+    }
+  }
+
+  // Slice exactly the first shape dimension:
+  //   begin = [0], end = [1], strides = [1]
+  auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
+  auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
+  auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
+  if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
+      *begin != 0 || *end != 1 || *strides != 1)
+    return {};
+
+  // First tensor dimension is dynamic.
+  auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
+  if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
+      !arg_ty.isDynamicDim(0))
+    return {};
+
+  // Argument tensor rank is equal to the number of packed dimensions.
+  if (arg_ty.getRank() != values().size()) return {};
+
+  // All other dimensions are statically known and equal to packed dims.
+  auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
+  if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
+    return {};
+
+  // Replace %pack with %shape.
+  return slice_op.input();
+}
+
 //===----------------------------------------------------------------------===//
 // PadOp
 //===----------------------------------------------------------------------===//
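Each early `return {}` above rejects a pack that does not reproduce the full input shape. For example (taken verbatim from the updated test below), a slice with begin = [1], end = [2] extracts dim1 rather than dim0, so the packed vector is not the shape of the input and the fold must leave it alone:

  // %18 packs [dim1, 1, 2], which is not the shape of %arg1 ([?, 1, 2]).
  %17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
  %18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>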
@@ -608,12 +699,11 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,

 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  results.insert<RedundantReshape>(context);
+  results.insert<RedundantReshape, ReshapeToSelfShape>(context);
 }

 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   Value tensor = this->tensor();
-  Value shape = this->shape();

   // Fold reshape if operand and result types are the same and all dimensions
   // are statically known (no-op reshape).
@@ -624,90 +714,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
     return tensor;
   }

-  // Fold reshape if the shape is computed from the input tensor:
-  //
-  //   %shape     = tf.Shape(%arg)                    // [? x ...]
-  //   %dim0      = tf.StridedSlice(%shape, 0, 1, 1)  // get unknown dim value
-  //   %new_shape = tf.Pack(dim0, ...) { axis = 0 }   // [? x ...]
-  //   %reshape   = tf.Reshape(%arg, %new_shape)      // this is no-op
-  //
-  // Where `...` are some statically known dimensions. In this case reshape is
-  // a no-op and can be replaced by %arg (assuming `...` are equal).
-  auto pack_op = dyn_cast_or_null<PackOp>(shape.getDefiningOp());
-  if (!pack_op || pack_op.values().size() < 2) return {};
-
-  // Dimensions packed along axis = 0 (pack scalars into vector).
-  if (pack_op.axis().getSExtValue() != 0) return {};
-
-  // First packed value is defined by a strided slice operation.
-  auto slice_op =
-      dyn_cast_or_null<StridedSliceOp>(pack_op.values()[0].getDefiningOp());
-  if (!slice_op) return {};
-
-  // Input to the slice op is defined by shape operation.
-  auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
-  if (!shape_op || shape_op.input() != tensor) return {};
-
-  // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
-  // scalar value from input vector).
-  if (slice_op.begin_mask().getSExtValue() != 0 ||
-      slice_op.ellipsis_mask().getSExtValue() != 0 ||
-      slice_op.end_mask().getSExtValue() != 0 ||
-      slice_op.new_axis_mask().getSExtValue() != 0 ||
-      slice_op.shrink_axis_mask().getSExtValue() != 1)
-    return {};
-
-  // Returns a value if the `value` is defined by a ConstOp with a single
-  // integer element in it and has an expected rank.
-  auto get_value = [](Value value, int expected_rank) -> Optional<int64_t> {
-    auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
-    if (!const_op) return None;
-
-    auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
-    if (!value_attr || value_attr.getNumElements() != 1) return None;
-
-    auto value_ty = value_attr.getType();
-    if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
-
-    auto splat = value_attr.getSplatValue<IntegerAttr>();
-    return splat.getValue().getSExtValue();
-  };
-
-  // All other packed values are scalar constants.
-  SmallVector<int64_t, 4> packed_dims;
-  packed_dims.reserve(pack_op.values().size() - 1);
-  for (Value operand : llvm::drop_begin(pack_op.values(), 1)) {
-    if (auto dim = get_value(operand, /*expected_rank=*/0)) {
-      packed_dims.push_back(*dim);
-    } else {
-      return {};
-    }
-  }
-
-  // Slice exactly the first shape dimension:
-  //   begin = [0], end = [1], strides = [1]
-  auto begin = get_value(slice_op.begin(), /*expected_rank=*/1);
-  auto end = get_value(slice_op.end(), /*expected_rank=*/1);
-  auto strides = get_value(slice_op.strides(), /*expected_rank=*/1);
-  if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
-      *begin != 0 || *end != 1 || *strides != 1)
-    return {};
-
-  // First tensor dimension is dynamic.
-  auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
-  if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
-      !arg_ty.isDynamicDim(0))
-    return {};
-
-  // Argument tensor rank is equal to the number of packed dimensions.
-  if (arg_ty.getRank() != pack_op.values().size()) return {};
-
-  // All other dimensions are statically known and equal to packed dims.
-  auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
-  if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
-    return {};
-
-  return tensor;
+  return {};
 }

 //===----------------------------------------------------------------------===//
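The removed Reshape fold is not lost: it decomposes into PackOp::fold plus the new ReshapeToSelfShape pattern registered above (and defined in the last hunk below). A sketch of the two-step rewrite, where `%c0`, `%c1`, and `%d4` are assumed constant operands:

  %shape = "tf.Shape"(%arg) : (tensor<?x4xf32>) -> tensor<2xi32>
  %dim0 = "tf.StridedSlice"(%shape, %c0, %c1, %c1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
  %new = "tf.Pack"(%dim0, %d4) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
  %r = "tf.Reshape"(%arg, %new) : (tensor<?x4xf32>, tensor<2xi32>) -> tensor<?x4xf32>

  // Step 1, PackOp::fold: %new is replaced with %shape.
  // Step 2, ReshapeToSelfShape: tf.Reshape(%arg, tf.Shape(%arg)) is replaced
  // with %arg, so %r folds away entirely.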
@@ -377,6 +377,15 @@ func @testRedundantReshape(%arg0: tensor<4x4xi32>) -> tensor<2x8xi32> {
   // CHECK: return %1 : tensor<2x8xi32>
 }

+// CHECK-LABEL: testReshapeToSelfShape
+func @testReshapeToSelfShape(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+  %0 = "tf.Shape"(%arg0) : (tensor<?x4xf32>) -> tensor<2xi32>
+  %1 = "tf.Reshape"(%arg0, %0) : (tensor<?x4xf32>, tensor<2xi32>) -> tensor<?x4xf32>
+
+  // CHECK: return %arg0 : tensor<?x4xf32>
+  return %1: tensor<?x4xf32>
+}
+
 // CHECK-LABEL: func @testReshapeNoOp
 func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x4xf32> {
   %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2xi32>) -> tensor<2x4xf32>
@@ -385,8 +394,8 @@ func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x
   return %0 : tensor<2x4xf32>
 }

-// CHECK-LABEL: func @testReshapeNoOpShapeComputation
-func @testReshapeNoOpShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<?x1xf32>, tensor<?x1x2xf32>, tensor<?x1x2xf32>, tensor<?x2x1xf32>, tensor<?x1x2xf32>, tensor<?x1x1xf32>, tensor<*xf32>) {
+// CHECK-LABEL: func @testPackShapeComputation
+func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
   // Test dimensions sizes.
   %d1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   %d2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
@@ -396,65 +405,56 @@ func @testReshapeNoOpShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x
   %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
   %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>

-  // Fold reshape if the shape is computed from the input tensor:
+  // Fold pack operation if it computes the input tensor shape:
   //
-  //   %shape     = tf.Shape(%arg)                    // [? x ...]
-  //   %dim0      = tf.StridedSlice(%shape, 0, 1, 1)  // get unknown dim value
-  //   %new_shape = tf.Pack(dim0, ...) { axis = 0 }   // [? x ...]
-  //   %reshape   = tf.Reshape(%arg, %new_shape)      // this is no-op
+  //   %shape = tf.Shape(%arg)                    // [? x ...]
+  //   %dim0  = tf.StridedSlice(%shape, 0, 1, 1)  // get unknown dim0 value
+  //   %pack  = tf.Pack(dim0, ...) { axis = 0 }   // [? x ...]
   //
-  // Where `...` are some statically known dimensions. In this case reshape is
-  // a no-op and can be replaced by %arg (assuming `...` are equal).
+  // Where `...` are some statically known dimensions. In this case %pack can
+  // be replaced with %shape. This is a common pattern in models with a
+  // dynamic batch size.

   // Test Rank 2
+  // CHECK: %[[SHAPE0:.*]] = "tf.Shape"
   %3 = "tf.Shape"(%arg0) : (tensor<?x1xf32>) -> tensor<2xi32>
   %4 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   %5 = "tf.Pack"(%4, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
-  %6 = "tf.Reshape"(%arg0, %5) : (tensor<?x1xf32>, tensor<2xi32>) -> tensor<?x1xf32>

   // Test Rank 3.
+  // CHECK: %[[SHAPE1:.*]] = "tf.Shape"
   %7 = "tf.Shape"(%arg1) : (tensor<?x1x2xf32>) -> tensor<3xi32>
   %8 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   %9 = "tf.Pack"(%8, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
-  %10 = "tf.Reshape"(%arg1, %9) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
-
-  // Shape was taken from the op that is not reshaped in the end:
-  //   Reshape(%arg1) vs Shape(%arg0)
-  %11 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
-  %12 = "tf.Pack"(%11, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
-  // CHECK: %[[RESHAPE0:.*]] = "tf.Reshape"
-  %13 = "tf.Reshape"(%arg1, %12) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>

   // Packed dimensions have different order from the reshape operand:
   //   [?, 1, 2] vs [?, 2, 1]
   %14 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   %15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
-  // CHECK: %[[RESHAPE1:.*]] = "tf.Reshape"
-  %16 = "tf.Reshape"(%arg1, %15) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x2x1xf32>
+  // CHECK: %[[PACK0:.*]] = "tf.Pack"

   // StridedSlice takes second dimension from the shape:
   //   begin = [1], end = [2], stride = [1]
   %17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   %18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
-  // CHECK: %[[RESHAPE2:.*]] = "tf.Reshape"
-  %19 = "tf.Reshape"(%arg1, %18) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
+  // CHECK: %[[PACK1:.*]] = "tf.Pack"

   // Packed dimensions have higher rank than the reshape operand:
   //   [?, 1] vs [?, 1, 1]
   %20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   %21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
-  // CHECK: %[[RESHAPE3:.*]] = "tf.Reshape"
-  %22 = "tf.Reshape"(%arg0, %21) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
+  // CHECK: %[[PACK2:.*]] = "tf.Pack"

   // Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass
   %23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
   %24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
   %25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
-  %26 = "tf.Reshape"(%arg2, %25) : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32>
+  // CHECK: %[[PACK3:.*]] = "tf.Pack"

-  // CHECK: return %arg0, %arg1, %[[RESHAPE0]], %[[RESHAPE1]], %[[RESHAPE2]], %[[RESHAPE3]]
-  return %6, %10, %13, %16, %19, %22, %26 : tensor<?x1xf32>, tensor<?x1x2xf32>, tensor<?x1x2xf32>, tensor<?x2x1xf32>, tensor<?x1x2xf32>, tensor<?x1x1xf32>, tensor<*xf32>
+  // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]]
+  return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
 }

 // CHECK-LABEL: testSelectScalarPred
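These cases are exercised by running the canonicalize pass over the test file; the file's usual RUN header (not shown in this hunk) is along the lines of:

  // RUN: tf-opt %s -canonicalize | FileCheck %s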
@@ -209,6 +209,11 @@ def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
 def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape),
                            (TF_ReshapeOp $arg, $shape)>;

+def IsSame : Constraint<CPred<"$0 == $1">>;
+def ReshapeToSelfShape : Pat<(TF_ReshapeOp $arg0, (TF_ShapeOp $arg1)),
+                             (replaceWithValue $arg0),
+                             [(IsSame $arg0, $arg1)]>;
+
 //===----------------------------------------------------------------------===//
 // Select op patterns.
 //===----------------------------------------------------------------------===//
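The IsSame constraint is what keeps ReshapeToSelfShape sound: the pattern fires only when the reshape operand and the tf.Shape operand are the same SSA value, not merely values of the same type. A counterexample that must stay untouched (hypothetical `%arg0`/`%arg1` of identical type, whose dynamic dimensions may still differ at runtime):

  // Not rewritten: %0 is the shape of %arg1, not of %arg0.
  %0 = "tf.Shape"(%arg1) : (tensor<?x4xf32>) -> tensor<2xi32>
  %1 = "tf.Reshape"(%arg0, %0) : (tensor<?x4xf32>, tensor<2xi32>) -> tensor<?x4xf32>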