[TF:MLIR] Fold PackOp if it computes tensor shape

Move shape-related PackOp folding out of ReshapeOp folding.
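
For illustration (not part of the commit itself), a minimal sketch of the IR this fold targets, mirroring the rank-2 test case below; %arg, %c0, %c1, %d1 are hypothetical names, with %c0/%c1 the tensor<1xi32> constants [0]/[1] and %d1 the scalar constant 1:

%shape = "tf.Shape"(%arg) : (tensor<?x1xf32>) -> tensor<2xi32>
%dim0 = "tf.StridedSlice"(%shape, %c0, %c1, %c1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%pack = "tf.Pack"(%dim0, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
// After this change, %pack folds directly to %shape.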

PiperOrigin-RevId: 322935716
Change-Id: I051fb8d7ed0c4507586d869b9c813f8b60634917
Eugene Zhulenev 2020-07-23 22:09:37 -07:00 committed by TensorFlower Gardener
parent fd1481780b
commit 37deabbb75
4 changed files with 126 additions and 112 deletions
tensorflow/compiler/mlir/tensorflow


@@ -6334,6 +6334,8 @@ This is the opposite of `unpack`.
let verifier = [{
return Verify(*this);
}];
let hasFolder = 1;
}
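
A note on `let hasFolder = 1;` (explanatory, not part of the diff): in MLIR ODS this flag makes TableGen declare a fold hook on the generated op class, which the C++ hunk below implements for PackOp. For a single-result op the declared hook is:

OpFoldResult fold(ArrayRef<Attribute> operands);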
def TF_PadOp : TF_Op<"Pad", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {


@@ -217,6 +217,97 @@ static LogicalResult Verify(PackOp op) {
return success();
}
OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
// Fold pack operation if it computes the input tensor shape:
//
// %shape = tf.Shape(%arg) // [? x ...]
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value
// %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
//
// Where `...` are some statically known dimensions. In this case %pack can be
// replaced with %shape. This is a common pattern in models with a dynamic
// batch size.
// Pack operation should pack at least two values.
if (values().size() < 2) return {};
// Dimensions packed along axis = 0 (pack scalars into vector).
if (axis().getSExtValue() != 0) return {};
// First packed value is defined by a strided slice operation.
auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
if (!slice_op) return {};
// Input to the slice op is defined by shape operation.
auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
if (!shape_op) return {};
// Input tensor, whose shape is reconstructed by the pack operation.
Value tensor = shape_op.input();
// All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
// scalar value from input vector).
if (slice_op.begin_mask().getSExtValue() != 0 ||
slice_op.ellipsis_mask().getSExtValue() != 0 ||
slice_op.end_mask().getSExtValue() != 0 ||
slice_op.new_axis_mask().getSExtValue() != 0 ||
slice_op.shrink_axis_mask().getSExtValue() != 1)
return {};
// Returns the integer element if `value` is defined by a ConstOp with a
// single integer element and the constant type has the expected rank.
auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
if (!const_op) return None;
auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
if (!value_attr || value_attr.getNumElements() != 1) return None;
auto value_ty = value_attr.getType();
if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
auto splat = value_attr.getSplatValue<IntegerAttr>();
return splat.getValue().getSExtValue();
};
// All other packed values are scalar constants.
SmallVector<int64_t, 4> packed_dims;
packed_dims.reserve(values().size() - 1);
for (Value operand : llvm::drop_begin(values(), 1)) {
if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
packed_dims.push_back(*dim);
} else {
return {};
}
}
// Slice exactly the first shape dimension:
//   begin = [0], end = [1], strides = [1]
auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
*begin != 0 || *end != 1 || *strides != 1)
return {};
// First tensor dimension is dynamic.
auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
!arg_ty.isDynamicDim(0))
return {};
// Argument tensor rank is equal to the number of packed dimensions.
if (arg_ty.getRank() != values().size()) return {};
// All other dimensions are statically known and equal to packed dims.
auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
return {};
// Replace %pack with %shape.
return slice_op.input();
}
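
To make the bail-out checks above concrete, here is a hypothetical case (not from the diff, constants %c0/%c1/%d1 as in the earlier sketch) that the `getNumDynamicDims() != 1` guard rejects; with two dynamic dimensions the packed value [?, 1] is not the shape of %arg, so fold returns {}:

%shape = "tf.Shape"(%arg) : (tensor<?x?xf32>) -> tensor<2xi32>
%dim0 = "tf.StridedSlice"(%shape, %c0, %c1, %c1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
// Not folded: the pack result would have to equal tf.Shape(%arg), but dim1 is unknown.
%pack = "tf.Pack"(%dim0, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>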
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//
@@ -608,12 +699,11 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<RedundantReshape>(context);
results.insert<RedundantReshape, ReshapeToSelfShape>(context);
}
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
Value tensor = this->tensor();
Value shape = this->shape();
// Fold reshape if operand and result types are the same and all dimensions
// are statically known (no-op reshape).
@@ -624,90 +714,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
return tensor;
}
// Fold reshape if the shape is computed from the input tensor:
//
// %shape = tf.Shape(%arg) // [? x ...]
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim value
// %new_shape = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
// %reshape = tf.Reshape(%arg, %new_shape) // this is no-op
//
// Where `...` are some statically known dimensions. In this case reshape is
// a no-op and can be replaced by %arg (assuming `...` are equal).
auto pack_op = dyn_cast_or_null<PackOp>(shape.getDefiningOp());
if (!pack_op || pack_op.values().size() < 2) return {};
// Dimensions packed along axis = 0 (pack scalars into vector).
if (pack_op.axis().getSExtValue() != 0) return {};
// First packed value is defined by a strided slice operation.
auto slice_op =
dyn_cast_or_null<StridedSliceOp>(pack_op.values()[0].getDefiningOp());
if (!slice_op) return {};
// Input to the slice op is defined by shape operation.
auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
if (!shape_op || shape_op.input() != tensor) return {};
// All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
// scalar value from input vector).
if (slice_op.begin_mask().getSExtValue() != 0 ||
slice_op.ellipsis_mask().getSExtValue() != 0 ||
slice_op.end_mask().getSExtValue() != 0 ||
slice_op.new_axis_mask().getSExtValue() != 0 ||
slice_op.shrink_axis_mask().getSExtValue() != 1)
return {};
// Returns the integer element if `value` is defined by a ConstOp with a
// single integer element and the constant type has the expected rank.
auto get_value = [](Value value, int expected_rank) -> Optional<int64_t> {
auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
if (!const_op) return None;
auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
if (!value_attr || value_attr.getNumElements() != 1) return None;
auto value_ty = value_attr.getType();
if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
auto splat = value_attr.getSplatValue<IntegerAttr>();
return splat.getValue().getSExtValue();
};
// All other packed values are scalar constants.
SmallVector<int64_t, 4> packed_dims;
packed_dims.reserve(pack_op.values().size() - 1);
for (Value operand : llvm::drop_begin(pack_op.values(), 1)) {
if (auto dim = get_value(operand, /*expected_rank=*/0)) {
packed_dims.push_back(*dim);
} else {
return {};
}
}
// Slice exactly the first shape dimension:
//   begin = [0], end = [1], strides = [1]
auto begin = get_value(slice_op.begin(), /*expected_rank=*/1);
auto end = get_value(slice_op.end(), /*expected_rank=*/1);
auto strides = get_value(slice_op.strides(), /*expected_rank=*/1);
if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
*begin != 0 || *end != 1 || *strides != 1)
return {};
// First tensor dimension is dynamic.
auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
!arg_ty.isDynamicDim(0))
return {};
// Argument tensor rank is equal to the number of packed dimensions.
if (arg_ty.getRank() != pack_op.values().size()) return {};
// All other dimensions are statically known and equal to packed dims.
auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
return {};
return tensor;
return {};
}
//===----------------------------------------------------------------------===//


@@ -377,6 +377,15 @@ func @testRedundantReshape(%arg0: tensor<4x4xi32>) -> tensor<2x8xi32> {
// CHECK: return %1 : tensor<2x8xi32>
}
// CHECK-LABEL: testReshapeToSelfShape
func @testReshapeToSelfShape(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
%0 = "tf.Shape"(%arg0) : (tensor<?x4xf32>) -> tensor<2xi32>
%1 = "tf.Reshape"(%arg0, %0) : (tensor<?x4xf32>, tensor<2xi32>) -> tensor<?x4xf32>
// CHECK: return %arg0 : tensor<?x4xf32>
return %1: tensor<?x4xf32>
}
// CHECK-LABEL: func @testReshapeNoOp
func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x4xf32> {
%0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2xi32>) -> tensor<2x4xf32>
@@ -385,8 +394,8 @@ func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x
return %0 : tensor<2x4xf32>
}
// CHECK-LABEL: func @testReshapeNoOpShapeComputation
func @testReshapeNoOpShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<?x1xf32>, tensor<?x1x2xf32>, tensor<?x1x2xf32>, tensor<?x2x1xf32>, tensor<?x1x2xf32>, tensor<?x1x1xf32>, tensor<*xf32>) {
// CHECK-LABEL: func @testPackShapeComputation
func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
// Test dimension sizes.
%d1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%d2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
@@ -396,65 +405,56 @@ func @testReshapeNoOpShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
// Fold reshape if the shape is computed from the input tensor:
// Fold pack operation if it computes the input tensor shape:
//
// %shape = tf.Shape(%arg) // [? x ...]
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim value
// %new_shape = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
// %reshape = tf.Reshape(%arg, %new_shape) // this is no-op
// %shape = tf.Shape(%arg) // [? x ...]
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value
// %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
//
// Where `...` are some statically known dimensions. In this case reshape is
// a no-op and can be replaced by %arg (assuming `...` are equal).
// Where `...` are some statically known dimensions. In this case %pack can be
// replaced with %shape. This is a common pattern in models with a dynamic
// batch size.
// Test Rank 2
// CHECK: %[[SHAPE0:.*]] = "tf.Shape"
%3 = "tf.Shape"(%arg0) : (tensor<?x1xf32>) -> tensor<2xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%5 = "tf.Pack"(%4, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
%6 = "tf.Reshape"(%arg0, %5) : (tensor<?x1xf32>, tensor<2xi32>) -> tensor<?x1xf32>
// Test Rank 3.
// CHECK: %[[SHAPE1:.*]] = "tf.Shape"
%7 = "tf.Shape"(%arg1) : (tensor<?x1x2xf32>) -> tensor<3xi32>
%8 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%9 = "tf.Pack"(%8, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
%10 = "tf.Reshape"(%arg1, %9) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
// Shape was taken from a tensor other than the one being reshaped:
// Reshape(%arg1) vs Shape(%arg0)
%11 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%12 = "tf.Pack"(%11, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[RESHAPE0:.*]] = "tf.Reshape"
%13 = "tf.Reshape"(%arg1, %12) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
// Packed dimensions have different order from the reshape operand:
// [?, 1, 2] vs [?, 2, 1]
%14 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[RESHAPE1:.*]] = "tf.Reshape"
%16 = "tf.Reshape"(%arg1, %15) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x2x1xf32>
// CHECK: %[[PACK0:.*]] = "tf.Pack"
// StridedSlice takes the second dimension from the shape:
// begin = [1], end = [2], stride = [1]
%17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[RESHAPE2:.*]] = "tf.Reshape"
%19 = "tf.Reshape"(%arg1, %18) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
// CHECK: %[[PACK1:.*]] = "tf.Pack"
// Packed dimensions have higher rank than the reshape operand:
// [?, 1] vs [?, 1, 1]
%20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[RESHAPE3:.*]] = "tf.Reshape"
%22 = "tf.Reshape"(%arg0, %21) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
// CHECK: %[[PACK2:.*]] = "tf.Pack"
// Make sure a dynamically ranked shape doesn't crash the "canonicalize" pass
%23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
%24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
%25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%26 = "tf.Reshape"(%arg2, %25) : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32>
// CHECK: %[[PACK3:.*]] = "tf.Pack"
// CHECK: return %arg0, %arg1, %[[RESHAPE0]], %[[RESHAPE1]], %[[RESHAPE2]], %[[RESHAPE3]]
return %6, %10, %13, %16, %19, %22, %26 : tensor<?x1xf32>, tensor<?x1x2xf32>, tensor<?x1x2xf32>, tensor<?x2x1xf32>, tensor<?x1x2xf32>, tensor<?x1x1xf32>, tensor<*xf32>
// CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]]
return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
}
// CHECK-LABEL: testSelectScalarPred


@@ -209,6 +209,11 @@ def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape),
(TF_ReshapeOp $arg, $shape)>;
def IsSame : Constraint<CPred<"$0 == $1">>;
def ReshapeToSelfShape : Pat<(TF_ReshapeOp $arg0, (TF_ShapeOp $arg1)),
(replaceWithValue $arg0),
[(IsSame $arg0, $arg1)]>;
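
For reference, the IR this pattern matches (mirroring the testReshapeToSelfShape case above): the IsSame constraint requires the reshape operand and the shape's operand to be the same value.

%0 = "tf.Shape"(%arg0) : (tensor<?x4xf32>) -> tensor<2xi32>
%1 = "tf.Reshape"(%arg0, %0) : (tensor<?x4xf32>, tensor<2xi32>) -> tensor<?x4xf32>
// ReshapeToSelfShape replaces %1 with %arg0.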
//===----------------------------------------------------------------------===//
// Select op patterns.
//===----------------------------------------------------------------------===//