From b858de3779cab56d9788b0e87ff437d5dffbd942 Mon Sep 17 00:00:00 2001
From: Jaesung Chung
Date: Mon, 16 Nov 2020 23:46:32 -0800
Subject: [PATCH] Do not apply optimizations when they will create 5-D
 broadcast operators

TFLite broadcastable operations can safely handle broadcast inputs of at
most four dimensions, so skip any rewrite that would produce a higher-rank
broadcast.

PiperOrigin-RevId: 342803141
Change-Id: Ib7035f5d6a5dbcc918431ae31cfda026b3c3f189
---
 .../compiler/mlir/lite/tests/optimize.mlir    | 85 +++++++++++++++++++
 .../compiler/mlir/lite/transforms/optimize.cc |  3 +-
 .../mlir/lite/transforms/optimize_patterns.td | 23 +++--
 3 files changed, 104 insertions(+), 7 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index e7516658b07..52021dccbfe 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -578,6 +578,32 @@ func @NotReorderReshapeAddIfNotTailingDimAfter(%arg0: tensor<1x30x1x96xf32>) ->
   // CHECK: return %[[rs2]]
 }

+// CHECK-LABEL: @NotReorderReshapeAddIf5DInputs
+func @NotReorderReshapeAddIf5DInputs(%arg0: tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x2xf32> {
+  %cst = constant dense<2.0> : tensor<1x1x1x1x2xf32>
+  %shape = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
+  %1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
+  %2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
+  return %2 : tensor<1x1x1x1x2xf32>
+
+  // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
+  // CHECK: %[[rs2:.*]] = tfl.add %[[rs1]]
+  // CHECK: return %[[rs2]]
+}
+
+// CHECK-LABEL: @NotReorderReshapeFloorDivIf5DInputs
+func @NotReorderReshapeFloorDivIf5DInputs(%arg0: tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x2xf32> {
+  %cst = constant dense<2.0> : tensor<1x1x1x1x2xf32>
+  %shape = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
+  %1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
+  %2 = "tfl.floor_div"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
+  return %2 : tensor<1x1x1x1x2xf32>
+
+  // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
+  // CHECK: %[[rs2:.*]] = tfl.floor_div %[[rs1]]
+  // CHECK: return %[[rs2]]
+}
+
 // CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
 func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
   %cst = constant dense<2.0> : tensor<1x40xf32>
@@ -896,6 +922,36 @@ func @fuseTileWithBinaryOp1(%arg0: tensor<1x1xf32>, %arg1: tensor<1x128xf32>) ->
   // CHECK: return %[[RES]]
 }

+// CHECK-LABEL: notFuseTileWithBinaryOpOn5DInputs
+func @notFuseTileWithBinaryOpOn5DInputs(%arg0: tensor<1x1xf32>) -> tensor<1x1x1x1x2xf32> {
+  %cst = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
+  %cst1 = constant dense<3.0> : tensor<1x1x1x1x2xf32>
+  %0 = "tfl.sqrt"(%arg0) : (tensor<1x1xf32>) -> tensor<1x1xf32>
+  %1 = "tfl.tile"(%0, %cst) : (tensor<1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
+  %2 = "tfl.add"(%cst1, %1) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
+  return %2 : tensor<1x1x1x1x2xf32>
+
+  // CHECK: "tfl.sqrt"
+  // CHECK: "tfl.tile"
+  // CHECK: tfl.add
+}
+
+// CHECK-LABEL: notFuseTileWithBinaryOp1On5DInputs
+func @notFuseTileWithBinaryOp1On5DInputs(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x1x1x128xf32>) -> tensor<1x1x1x1x128xf32> {
+  %cst_0 = constant dense<1.0> : tensor<f32>
+  %cst_1 = constant dense<[1, 1, 1, 1, 128]> : tensor<5xi32>
+  %0 = "tfl.add"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<f32>) -> tensor<1x1xf32>
+  %1 = "tfl.sqrt"(%0) : (tensor<1x1xf32>) -> tensor<1x1xf32>
+  %2 = "tfl.tile"(%1, %cst_1) : (tensor<1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x128xf32>
+  %3 = "tfl.div"(%2, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x128xf32>, tensor<1x1x1x1x128xf32>) -> tensor<1x1x1x1x128xf32>
+  return %3 : tensor<1x1x1x1x128xf32>
+
+  // CHECK: "tfl.add"
+  // CHECK: "tfl.sqrt"
+  // CHECK: "tfl.tile"
+  // CHECK: tfl.div
+}
+
 // CHECK-LABEL: InvalidFuseTileWithBinaryOp
 func @InvalidFuseTileWithBinaryOp(%arg0: tensor<2x3xf32>) -> tensor<2x6xf32> {
   %cst = constant dense<[[1,2]]> : tensor<1x2xi32>
@@ -1155,6 +1211,18 @@ func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   // CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32>
 }

+func @NotReorderAddWithConstantOn5D(%arg0: tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32> {
+  %cst = constant dense<1.0> : tensor<2x2x2x2x2xf32>
+  %cst_1 = constant dense<2.0> : tensor<2x2x2x2x2xf32>
+  %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x2x2x2x2xf32>, tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32>
+  %1 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2x2x2x2xf32>, tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32>
+  return %1 : tensor<2x2x2x2x2xf32>
+
+  // CHECK-LABEL: NotReorderAddWithConstantOn5D
+  // CHECK: tfl.add
+  // CHECK: tfl.add
+}
+
 func @RemoveCast(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %1 = "tfl.cast"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
   return %1 : tensor<2x2xf32>
@@ -1427,3 +1495,20 @@ func @fuseMulIntoConv2d_Splat2D(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x
   // CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
   // CHECK: return %[[RES]]
 }
+
+// CHECK-LABEL: @AvoidFuseFullyConnectedAddWithSplat2D
+func @AvoidFuseFullyConnectedAddWithSplat2D(%arg0: tensor<1x1x1x1x1xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x1x1x1x1xf32> {
+  %cst = constant unit
+  %cst2 = constant dense<2.0> : tensor<1x1x1x1x1xf32>
+
+  %0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32>
+  %1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x1xf32>, tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x1xf32>
+
+  return %1 : tensor<1x1x1x1x1xf32>
+
+  // CHECK: %[[CST1:.*]] = constant unit
+  // CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<1x1x1x1x1xf32>
+  // CHECK: %[[FC_RESULT:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[CST1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32>
+  // CHECK: %[[ADD:.*]] = tfl.add %[[FC_RESULT]], %[[CST2]] {fused_activation_function = "NONE"} : tensor<1x1x1x1x1xf32>
+  // CHECK: return %[[ADD]] : tensor<1x1x1x1x1xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index a72537f4963..c371c4f6e56 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -778,7 +778,8 @@ struct ScalarizeSplatConstantForBroadcastableOps
     // cannot scalarize the splat constant because the result shape relies on
     // the splat constant op's shape for broadcasting.
     if (!non_splat_operand_type.hasStaticShape() ||
-        non_splat_operand_type.getShape() != result_type.getShape()) {
+        non_splat_operand_type.getShape() != result_type.getShape() ||
+        non_splat_operand_type.getRank() > 4) {
       return failure();
     }

diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 7eeffa26bd5..6c47995f685 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -376,13 +376,17 @@ multiclass FuseTileBroadcastIntoFollowingBinary {
     (BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
      $operand, $act_func),
     (BinaryOp $input, $operand, $act_func),
-    [(OperandsBroadcastToOutputType $input, $operand, $result)]>;
+    [(OperandsBroadcastToOutputType $input, $operand, $result),
+     (HasRankAtMost<4> $input),
+     (HasRankAtMost<4> $operand)]>;

   def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
     (BinaryOp:$result $operand,
      (TFL_TileOp $input, (ConstantOp $tile)), $act_func),
     (BinaryOp $operand, $input, $act_func),
-    [(OperandsBroadcastToOutputType $operand, $input, $result)]>;
+    [(OperandsBroadcastToOutputType $operand, $input, $result),
+     (HasRankAtMost<4> $operand),
+     (HasRankAtMost<4> $input)]>;
 }

 // Multi-pattern consisting of matching stand-alone op or op followed by relu.
@@ -427,8 +431,9 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
           // `input`. In other words, the shape of the `Reshape` op are not
           // changed after the transformation.
           (IsTailOfShape $rhs, $input),
-          (HasRankAtMost<5> $input),
-          (HasRankAtMost<5> $rhs)]>;
+          (HasRankAtMost<4> $input),
+          (HasRankAtMost<4> $lhs),
+          (HasRankAtMost<4> $rhs)]>;
 }

 foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
@@ -457,7 +462,10 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
           // The result of the new "BinaryOp" will have the same shape as
           // `input`. In other words, the shape of the `Reshape` op are not
           // changed after the transformation.
-          (IsTailOfShape $rhs, $input)]>;
+          (IsTailOfShape $rhs, $input),
+          (HasRankAtMost<4> $input),
+          (HasRankAtMost<4> $lhs),
+          (HasRankAtMost<4> $rhs)]>;
 }

 // Reorder the element-wise value operations and the element move operations,
@@ -568,7 +576,10 @@ foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
     (TFL_AddOp $input,
      (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
      ActFun),
-    [(HasOneUse $first_output)]>;
+    [(HasOneUse $first_output),
+     (HasRankAtMost<4> $input),
+     (HasRankAtMost<4> $a),
+     (HasRankAtMost<4> $b)]>;
 }

 // We can eliminate Relu from Relu(SquaredDifference(x, y)),
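Note on the optimize.cc change: the new condition simply refuses the splat-constant
scalarization whenever the non-splat operand has rank greater than 4, since the TFLite
broadcast kernels only support inputs of up to four dimensions. Below is a minimal,
self-contained C++ sketch of that guard; the helper name, the Shape alias, and the
boolean parameter are illustrative stand-ins for the MLIR ShapedType queries used in
the real pass, not actual TFLite or MLIR APIs.

#include <cstdint>
#include <vector>

// Illustrative stand-in for a static tensor shape: a list of dimension sizes.
using Shape = std::vector<int64_t>;

// Sketch of the check performed by ScalarizeSplatConstantForBroadcastableOps:
// the rewrite is rejected when the non-splat operand has no static shape, when
// its shape differs from the result shape, or (new in this patch) when its
// rank exceeds 4.
bool CanScalarizeSplatConstant(bool has_static_shape,
                               const Shape& non_splat_operand_shape,
                               const Shape& result_shape) {
  if (!has_static_shape) return false;
  if (non_splat_operand_shape != result_shape) return false;
  if (non_splat_operand_shape.size() > 4) return false;  // would need a 5-D broadcast
  return true;
}

The same rank-4 bound is what the added HasRankAtMost<4> constraints impose on the
TableGen rewrite patterns, and what the new optimize.mlir tests verify is left
untouched by the optimizer.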