diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 35a6b0e2343..0851975e8e1 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -6334,6 +6334,8 @@ This is the opposite of `unpack`.
   let verifier = [{
     return Verify(*this);
   }];
+
+  let hasFolder = 1;
 }
 
 def TF_PadOp : TF_Op<"Pad", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 0d9b2610492..785b0bac820 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -217,6 +217,97 @@ static LogicalResult Verify(PackOp op) {
   return success();
 }
 
+OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
+  // Fold pack operation if it computes the input tensor shape:
+  //
+  //   %shape = tf.Shape(%arg)                    // [? x ...]
+  //   %dim0 = tf.StridedSlice(%shape, 0, 1, 1)   // get unknown dim0 value
+  //   %pack = tf.Pack(dim0, ...) { axis = 0 }    // [? x ...]
+  //
+  // Where `...` are some statically known dimensions. In this case %pack can
+  // be replaced with a %shape. This is a common pattern in models with a
+  // dynamic batch size.
+
+  // Pack operation should pack at least two values.
+  if (values().size() < 2) return {};
+
+  // Dimensions packed along axis = 0 (pack scalars into vector).
+  if (axis().getSExtValue() != 0) return {};
+
+  // First packed value is defined by a strided slice operation.
+  auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
+  if (!slice_op) return {};
+
+  // Input to the slice op is defined by shape operation.
+  auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
+  if (!shape_op) return {};
+
+  // Input tensor, whose shape is reconstructed by the pack operation.
+  Value tensor = shape_op.input();
+
+  // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
+  // scalar value from input vector).
+  if (slice_op.begin_mask().getSExtValue() != 0 ||
+      slice_op.ellipsis_mask().getSExtValue() != 0 ||
+      slice_op.end_mask().getSExtValue() != 0 ||
+      slice_op.new_axis_mask().getSExtValue() != 0 ||
+      slice_op.shrink_axis_mask().getSExtValue() != 1)
+    return {};
+
+  // Returns a value if the `value` is defined by a ConstOp with a single
+  // integer element in it and has an expected rank.
+  auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
+    auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
+    if (!const_op) return None;
+
+    auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
+    if (!value_attr || value_attr.getNumElements() != 1) return None;
+
+    auto value_ty = value_attr.getType();
+    if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
+
+    auto splat = value_attr.getSplatValue<IntegerAttr>();
+    return splat.getValue().getSExtValue();
+  };
+
+  // All other packed values are scalar constants.
+  SmallVector<int64_t, 4> packed_dims;
+  packed_dims.reserve(values().size() - 1);
+  for (Value operand : llvm::drop_begin(values(), 1)) {
+    if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
+      packed_dims.push_back(*dim);
+    } else {
+      return {};
+    }
+  }
+
+  // Slice exactly the first shape dimension:
+  //   begin = [0], end = [1], strides = [1]
+  auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
+  auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
+  auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
+  if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
+      *begin != 0 || *end != 1 || *strides != 1)
+    return {};
+
+  // First tensor dimension is dynamic.
+  auto arg_ty = tensor.getType().dyn_cast<RankedTensorType>();
+  if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
+      !arg_ty.isDynamicDim(0))
+    return {};
+
+  // Argument tensor rank is equal to the number of packed dimensions.
+  if (arg_ty.getRank() != values().size()) return {};
+
+  // All other dimensions are statically known and equal to packed dims.
+  auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
+  if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
+    return {};
+
+  // Replace %pack with %shape.
+  return slice_op.input();
+}
+
 //===----------------------------------------------------------------------===//
 // PadOp
 //===----------------------------------------------------------------------===//
@@ -608,12 +699,11 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
 
 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  results.insert<RedundantReshape>(context);
+  results.insert<RedundantReshape, ReshapeToSelfShape>(context);
 }
 
 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   Value tensor = this->tensor();
-  Value shape = this->shape();
 
   // Fold reshape if operand and result types are the same and all dimensions
   // are statically known (no-op reshape).
@@ -624,90 +714,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
     return tensor;
   }
 
-  // Fold reshape if the shape is computed from the input tensor:
-  //
-  //   %shape = tf.Shape(%arg)                       // [? x ...]
-  //   %dim0 = tf.StridedSlice(%shape, 0, 1, 1)      // get unknown dim value
-  //   %new_shape = tf.Pack(dim0, ...) { axis = 0 }  // [? x ...]
-  //   %reshape = tf.Reshape(%arg, %new_shape)       // this is no-op
-  //
-  // Where `...` are some statically known dimensions. In this case reshape is
-  // a no-op and can be replaced by %arg (assuming `...` are equal).
-  auto pack_op = dyn_cast_or_null<PackOp>(shape.getDefiningOp());
-  if (!pack_op || pack_op.values().size() < 2) return {};
-
-  // Dimensions packed along axis = 0 (pack scalars into vector).
-  if (pack_op.axis().getSExtValue() != 0) return {};
-
-  // First packed value is defined by a strided slice operation.
-  auto slice_op =
-      dyn_cast_or_null<StridedSliceOp>(pack_op.values()[0].getDefiningOp());
-  if (!slice_op) return {};
-
-  // Input to the slice op is defined by shape operation.
-  auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
-  if (!shape_op || shape_op.input() != tensor) return {};
-
-  // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
-  // scalar value from input vector).
-  if (slice_op.begin_mask().getSExtValue() != 0 ||
-      slice_op.ellipsis_mask().getSExtValue() != 0 ||
-      slice_op.end_mask().getSExtValue() != 0 ||
-      slice_op.new_axis_mask().getSExtValue() != 0 ||
-      slice_op.shrink_axis_mask().getSExtValue() != 1)
-    return {};
-
-  // Returns a value if the `value` is defined by a ConstOp with a single
-  // integer element in it and has an expected rank.
-  auto get_value = [](Value value, int expected_rank) -> Optional<int64_t> {
-    auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
-    if (!const_op) return None;
-
-    auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
-    if (!value_attr || value_attr.getNumElements() != 1) return None;
-
-    auto value_ty = value_attr.getType();
-    if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
-
-    auto splat = value_attr.getSplatValue<IntegerAttr>();
-    return splat.getValue().getSExtValue();
-  };
-
-  // All other packed values are scalar constants.
-  SmallVector<int64_t, 4> packed_dims;
-  packed_dims.reserve(pack_op.values().size() - 1);
-  for (Value operand : llvm::drop_begin(pack_op.values(), 1)) {
-    if (auto dim = get_value(operand, /*expected_rank=*/0)) {
-      packed_dims.push_back(*dim);
-    } else {
-      return {};
-    }
-  }
-
-  // Slice exactly the first shape dimension:
-  //   begin = [0], end = [1], strides = [1]
-  auto begin = get_value(slice_op.begin(), /*expected_rank=*/1);
-  auto end = get_value(slice_op.end(), /*expected_rank=*/1);
-  auto strides = get_value(slice_op.strides(), /*expected_rank=*/1);
-  if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
-      *begin != 0 || *end != 1 || *strides != 1)
-    return {};
-
-  // First tensor dimension is dynamic.
-  auto arg_ty = tensor.getType().dyn_cast<RankedTensorType>();
-  if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
-      !arg_ty.isDynamicDim(0))
-    return {};
-
-  // Argument tensor rank is equal to the number of packed dimensions.
-  if (arg_ty.getRank() != pack_op.values().size()) return {};
-
-  // All other dimensions are statically known and equal to packed dims.
-  auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
-  if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
-    return {};
-
-  return tensor;
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index 17a19c50998..42659f41c21 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -377,6 +377,15 @@ func @testRedundantReshape(%arg0: tensor<4x4xi32>) -> tensor<2x8xi32> {
   // CHECK: return %1 : tensor<2x8xi32>
 }
 
+// CHECK-LABEL: testReshapeToSelfShape
+func @testReshapeToSelfShape(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+  %0 = "tf.Shape"(%arg0) : (tensor<?x4xf32>) -> tensor<2xi32>
+  %1 = "tf.Reshape"(%arg0, %0) : (tensor<?x4xf32>, tensor<2xi32>) -> tensor<?x4xf32>
+
+  // CHECK: return %arg0 : tensor<?x4xf32>
+  return %1: tensor<?x4xf32>
+}
+
 // CHECK-LABEL: func @testReshapeNoOp
 func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x4xf32> {
   %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2xi32>) -> tensor<2x4xf32>
@@ -385,8 +394,8 @@ func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x
   // CHECK: return %arg0
   return %0 : tensor<2x4xf32>
 }
 
-// CHECK-LABEL: func @testReshapeNoOpShapeComputation
-func @testReshapeNoOpShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<?x1xf32>, tensor<?x1x2xf32>, tensor<?x1x2xf32>, tensor<?x2x1xf32>, tensor<?x1x2xf32>, tensor<?x1x1xf32>, tensor<*xf32>) {
+// CHECK-LABEL: func @testPackShapeComputation
+func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
   // Test dimensions sizes.
   %d1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
   %d2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
@@ -396,65 +405,56 @@ func @testReshapeNoOpShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>)
   %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
   %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
 
-  // Fold reshape if the shape is computed from the input tensor:
+  // Fold pack operation if it computes the input tensor shape:
   //
-  //   %shape = tf.Shape(%arg)                       // [? x ...]
-  //   %dim0 = tf.StridedSlice(%shape, 0, 1, 1)      // get unknown dim value
-  //   %new_shape = tf.Pack(dim0, ...) { axis = 0 }  // [? x ...]
-  //   %reshape = tf.Reshape(%arg, %new_shape)       // this is no-op
+  //   %shape = tf.Shape(%arg)                    // [? x ...]
+  //   %dim0 = tf.StridedSlice(%shape, 0, 1, 1)   // get unknown dim0 value
+  //   %pack = tf.Pack(dim0, ...) { axis = 0 }    // [? x ...]
   //
-  // Where `...` are some statically known dimensions. In this case reshape is
-  // a no-op and can be replaced by %arg (assuming `...` are equal).
+  // Where `...` are some statically known dimensions. In this case %pack can
+  // be replaced with a %shape. This is a common pattern in models with a
+  // dynamic batch size.
 
   // Test Rank 2
+  // CHECK: %[[SHAPE0:.*]] = "tf.Shape"
   %3 = "tf.Shape"(%arg0) : (tensor<?x1xf32>) -> tensor<2xi32>
   %4 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   %5 = "tf.Pack"(%4, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
   %6 = "tf.Reshape"(%arg0, %5) : (tensor<?x1xf32>, tensor<2xi32>) -> tensor<?x1xf32>
 
   // Test Rank 3.
-
+  // CHECK: %[[SHAPE1:.*]] = "tf.Shape"
   %7 = "tf.Shape"(%arg1) : (tensor<?x1x2xf32>) -> tensor<3xi32>
   %8 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   %9 = "tf.Pack"(%8, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
   %10 = "tf.Reshape"(%arg1, %9) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
 
-  // Shape was taken from the op that is not reshaped in the end:
-  //   Reshape(%arg1) vs Shape(%arg0)
-  %11 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
-  %12 = "tf.Pack"(%11, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
-  // CHECK: %[[RESHAPE0:.*]] = "tf.Reshape"
-  %13 = "tf.Reshape"(%arg1, %12) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
-
   // Packed dimensions have different order from the reshape operand:
   //   [?, 1, 2] vs [?, 2, 1]
   %14 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   %15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
-  // CHECK: %[[RESHAPE1:.*]] = "tf.Reshape"
-  %16 = "tf.Reshape"(%arg1, %15) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x2x1xf32>
+  // CHECK: %[[PACK0:.*]] = "tf.Pack"
 
   // StridedSlice takes second dimension from the shape:
   //   begin = [1], end = [2], stride = [1]
   %17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   %18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
-  // CHECK: %[[RESHAPE2:.*]] = "tf.Reshape"
-  %19 = "tf.Reshape"(%arg1, %18) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
+  // CHECK: %[[PACK1:.*]] = "tf.Pack"
 
   // Packed dimensions have higher rank than the reshape operand:
   //   [?, 1] vs [?, 1, 1]
   %20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
   %21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
-  // CHECK: %[[RESHAPE3:.*]] = "tf.Reshape"
-  %22 = "tf.Reshape"(%arg0, %21) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
+  // CHECK: %[[PACK2:.*]] = "tf.Pack"
 
   // Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass
   %23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
   %24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
   %25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
-  %26 = "tf.Reshape"(%arg2, %25) : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32>
+  // CHECK: %[[PACK3:.*]] = "tf.Pack"
 
-  // CHECK: return %arg0, %arg1, %[[RESHAPE0]], %[[RESHAPE1]], %[[RESHAPE2]], %[[RESHAPE3]]
-  return %6, %10, %13, %16, %19, %22, %26 : tensor<?x1xf32>, tensor<?x1x2xf32>, tensor<?x1x2xf32>, tensor<?x2x1xf32>, tensor<?x1x2xf32>, tensor<?x1x1xf32>, tensor<*xf32>
+  // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]]
+  return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
 }
 
 // CHECK-LABEL: testSelectScalarPred
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
index 3f0b5b48af9..d5b7eb7a739 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
@@ -209,6 +209,11 @@ def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
 def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape),
                            (TF_ReshapeOp $arg, $shape)>;
 
+def IsSame : Constraint<CPred<"$0 == $1">>;
+def ReshapeToSelfShape : Pat<(TF_ReshapeOp $arg0, (TF_ShapeOp $arg1)),
+                             (replaceWithValue $arg0),
+                             [(IsSame $arg0, $arg1)]>;
+
 //===----------------------------------------------------------------------===//
 // Select op patterns.
 //===----------------------------------------------------------------------===//
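
For context, here is a minimal before/after sketch of the rewrite the new PackOp folder performs. The IR is illustrative only and not part of the patch; the value names (%arg, %c0, %c1, %c4) and the <?x4xf32> shape are hypothetical. %c0/%c1 are tensor<1xi32> constants 0 and 1, %c4 is a tensor<i32> constant 4 matching the static dim1:

  // Before folding: %pack recomputes the shape of %arg from its dynamic
  // dim0 and the statically known dim1 = 4.
  %shape = "tf.Shape"(%arg) : (tensor<?x4xf32>) -> tensor<2xi32>
  %dim0 = "tf.StridedSlice"(%shape, %c0, %c1, %c1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
  %pack = "tf.Pack"(%dim0, %c4) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>

  // After folding: all uses of %pack are replaced with %shape, and the
  // now-dead StridedSlice/Pack pair is cleaned up by the canonicalizer.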
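The new ReshapeToSelfShape pattern then handles the remaining case directly in DRR: reshaping a tensor to its own dynamically computed shape is a no-op. The IsSame constraint requires the reshape operand and the tf.Shape operand to be the same SSA value. A hypothetical example of IR it erases (shapes are illustrative):

  %0 = "tf.Shape"(%arg0) : (tensor<?x4xf32>) -> tensor<2xi32>
  %1 = "tf.Reshape"(%arg0, %0) : (tensor<?x4xf32>, tensor<2xi32>) -> tensor<?x4xf32>
  // Canonicalization rewrites all uses of %1 to %arg0.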