Do not apply optimizations when they will create 5-D broadcast operators
TFLite broadcastable operations can only handle broadcast inputs of up to 4 dimensions safely, so optimization patterns must not rewrite graphs into higher-rank broadcasting ops.

PiperOrigin-RevId: 342803141
Change-Id: Ib7035f5d6a5dbcc918431ae31cfda026b3c3f189
Parent: b437c554e8
Commit: b858de3779
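For context, a minimal hand-written MLIR sketch (shapes invented for illustration, not taken from this change) of the kind of op the guarded rewrites must not produce: a tfl.add whose broadcast spans five dimensions, which TFLite's broadcast kernels cannot execute.

// Illustration only: a broadcasting add over 5-D operands. TFLite kernels
// support broadcasting up to 4-D, so optimization patterns must not rewrite
// graphs into this form.
func @illustrative_broadcast_add_5d(%arg0: tensor<1x1x1x1x2xf32>) -> tensor<2x3x4x5x2xf32> {
  %cst = constant dense<1.0> : tensor<2x3x4x5x2xf32>
  %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<2x3x4x5x2xf32>) -> tensor<2x3x4x5x2xf32>
  return %0 : tensor<2x3x4x5x2xf32>
}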
@@ -578,6 +578,32 @@ func @NotReorderReshapeAddIfNotTailingDimAfter(%arg0: tensor<1x30x1x96xf32>) ->
   // CHECK: return %[[rs2]]
 }
 
+// CHECK-LABEL: @NotReorderReshapeAddIf5DInputs
+func @NotReorderReshapeAddIf5DInputs(%arg0: tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x2xf32> {
+  %cst = constant dense<2.0> : tensor<1x1x1x1x2xf32>
+  %shape = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
+  %1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
+  %2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
+  return %2 : tensor<1x1x1x1x2xf32>
+
+  // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
+  // CHECK: %[[rs2:.*]] = tfl.add %[[rs1]]
+  // CHECK: return %[[rs2]]
+}
+
+// CHECK-LABEL: @NotReorderReshapeFloorDivIf5DInputs
+func @NotReorderReshapeFloorDivIf5DInputs(%arg0: tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x2xf32> {
+  %cst = constant dense<2.0> : tensor<1x1x1x1x2xf32>
+  %shape = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
+  %1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
+  %2 = "tfl.floor_div"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
+  return %2 : tensor<1x1x1x1x2xf32>
+
+  // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
+  // CHECK: %[[rs2:.*]] = tfl.floor_div %[[rs1]]
+  // CHECK: return %[[rs2]]
+}
+
 // CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
 func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
   %cst = constant dense<2.0> : tensor<1x40xf32>
@@ -896,6 +922,36 @@ func @fuseTileWithBinaryOp1(%arg0: tensor<1x1xf32>, %arg1: tensor<1x128xf32>) ->
   // CHECK: return %[[RES]]
 }
 
+// CHECK-LABEL: notFuseTileWithBinaryOpOn5DInputs
+func @notFuseTileWithBinaryOpOn5DInputs(%arg0: tensor<1x1xf32>) -> tensor<1x1x1x1x2xf32> {
+  %cst = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
+  %cst1 = constant dense<3.0> : tensor<1x1x1x1x2xf32>
+  %0 = "tfl.sqrt"(%arg0) : (tensor<1x1xf32>) -> tensor<1x1xf32>
+  %1 = "tfl.tile"(%0, %cst) : (tensor<1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
+  %2 = "tfl.add"(%cst1, %1) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
+  return %2 : tensor<1x1x1x1x2xf32>
+
+  // CHECK: "tfl.sqrt"
+  // CHECK: "tfl.tile"
+  // CHECK: tfl.add
+}
+
+// CHECK-LABEL: notFuseTileWithBinaryOp1On5DInputs
+func @notFuseTileWithBinaryOp1On5DInputs(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x1x1x128xf32>) -> tensor<1x1x1x1x128xf32> {
+  %cst_0 = constant dense<1.0> : tensor<f32>
+  %cst_1 = constant dense<[1, 1, 1, 1, 128]> : tensor<5xi32>
+  %0 = "tfl.add"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<f32>) -> tensor<1x1xf32>
+  %1 = "tfl.sqrt"(%0) : (tensor<1x1xf32>) -> tensor<1x1xf32>
+  %2 = "tfl.tile"(%1, %cst_1) : (tensor<1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x128xf32>
+  %3 = "tfl.div"(%2, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x128xf32>, tensor<1x1x1x1x128xf32>) -> tensor<1x1x1x1x128xf32>
+  return %3 : tensor<1x1x1x1x128xf32>
+
+  // CHECK: "tfl.add"
+  // CHECK: "tfl.sqrt"
+  // CHECK: "tfl.tile"
+  // CHECK: tfl.div
+}
+
 // CHECK-LABEL: InvalidFuseTileWithBinaryOp
 func @InvalidFuseTileWithBinaryOp(%arg0: tensor<2x3xf32>) -> tensor<2x6xf32> {
   %cst = constant dense<[[1,2]]> : tensor<1x2xi32>
@@ -1155,6 +1211,18 @@ func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   // CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32>
 }
 
+func @NotReorderAddWithConstantOn5D(%arg0: tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32> {
+  %cst = constant dense<1.0> : tensor<2x2x2x2x2xf32>
+  %cst_1 = constant dense<2.0> : tensor<2x2x2x2x2xf32>
+  %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x2x2x2x2xf32>, tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32>
+  %1 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2x2x2x2xf32>, tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32>
+  return %1 : tensor<2x2x2x2x2xf32>
+
+  // CHECK-LABEL: NotReorderAddWithConstantOn5D
+  // CHECK: tfl.add
+  // CHECK: tfl.add
+}
+
 func @RemoveCast(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %1 = "tfl.cast"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
   return %1 : tensor<2x2xf32>
@@ -1427,3 +1495,20 @@ func @fuseMulIntoConv2d_Splat2D(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x
   // CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
   // CHECK: return %[[RES]]
 }
+
+// CHECK-LABEL: @AvoidFuseFullyConnectedAddWithSplat2D
+func @AvoidFuseFullyConnectedAddWithSplat2D(%arg0: tensor<1x1x1x1x1xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x1x1x1x1xf32> {
+  %cst = constant unit
+  %cst2 = constant dense<2.0> : tensor<1x1x1x1x1xf32>
+
+  %0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32>
+  %1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x1xf32>, tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x1xf32>
+
+  return %1 : tensor<1x1x1x1x1xf32>
+
+  // CHECK: %[[CST1:.*]] = constant unit
+  // CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<1x1x1x1x1xf32>
+  // CHECK: %[[FC_RESULT:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[CST1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32>
+  // CHECK: %[[ADD:.*]] = tfl.add %[[FC_RESULT]], %[[CST2]] {fused_activation_function = "NONE"} : tensor<1x1x1x1x1xf32>
+  // CHECK: return %[[ADD]] : tensor<1x1x1x1x1xf32>
+}
@@ -778,7 +778,8 @@ struct ScalarizeSplatConstantForBroadcastableOps
     // cannot scalarize the splat constant because the result shape relies on
     // the splat constant op's shape for broadcasting.
     if (!non_splat_operand_type.hasStaticShape() ||
-        non_splat_operand_type.getShape() != result_type.getShape()) {
+        non_splat_operand_type.getShape() != result_type.getShape() ||
+        non_splat_operand_type.getRank() > 4) {
       return failure();
     }
 
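The comment in the hunk above notes that the splat constant cannot be scalarized when the non-splat operand's shape differs from the result's, because the result shape relies on the splat constant's shape for broadcasting. A minimal sketch of that situation (shapes invented for illustration, not part of the change): the splat constant supplies the 2x2 output shape, so replacing it with a tensor<f32> scalar would change the op's result type.

// Illustration only: the 2x2 splat constant drives the broadcasted 2x2
// result; scalarizing it to a tensor<f32> would leave nothing for the 1x2
// operand to broadcast against, changing the result shape.
func @illustrative_splat_supplies_result_shape(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
  %cst = constant dense<3.0> : tensor<2x2xf32>
  %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}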
@@ -376,13 +376,17 @@ multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
     (BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
                       $operand, $act_func),
     (BinaryOp $input, $operand, $act_func),
-    [(OperandsBroadcastToOutputType $input, $operand, $result)]>;
+    [(OperandsBroadcastToOutputType $input, $operand, $result),
+     (HasRankAtMost<4> $input),
+     (HasRankAtMost<4> $operand)]>;
 
   def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
     (BinaryOp:$result $operand,
      (TFL_TileOp $input, (ConstantOp $tile)), $act_func),
     (BinaryOp $operand, $input, $act_func),
-    [(OperandsBroadcastToOutputType $operand, $input, $result)]>;
+    [(OperandsBroadcastToOutputType $operand, $input, $result),
+     (HasRankAtMost<4> $operand),
+     (HasRankAtMost<4> $input)]>;
 }
 
 // Multi-pattern consisting of matching stand-alone op or op followed by relu.
@@ -427,8 +431,9 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
     // `input`. In other words, the shape of the `Reshape` op are not
     // changed after the transformation.
     (IsTailOfShape $rhs, $input),
-    (HasRankAtMost<5> $input),
-    (HasRankAtMost<5> $rhs)]>;
+    (HasRankAtMost<4> $input),
+    (HasRankAtMost<4> $lhs),
+    (HasRankAtMost<4> $rhs)]>;
 }
 
 foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
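For reference, a hand-written before/after sketch (shapes invented, not taken from this change) of the reshape/binary reorder that these constraints gate: the constant is a tail of both the reshape result shape and the original input shape, so the binary op can be moved ahead of the reshape without changing the result. The new HasRankAtMost<4> predicates keep this rewrite from producing a broadcasting op over operands of rank 5 or more.

// Illustration only: [40] is a tail of both [40, 40] and [1, 40, 40], so the
// reorder pattern could rewrite @illustrative_before into the form of
// @illustrative_after. With a 5-D input, the moved add would broadcast over
// 5-D operands, which the rank-4 constraints now rule out.
func @illustrative_before(%arg0: tensor<1x40x40xf32>) -> tensor<40x40xf32> {
  %shape = constant dense<[40, 40]> : tensor<2xi32>
  %cst = constant dense<1.0> : tensor<40xf32>
  %0 = "tfl.reshape"(%arg0, %shape) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32>
  %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40xf32>) -> tensor<40x40xf32>
  return %1 : tensor<40x40xf32>
}

func @illustrative_after(%arg0: tensor<1x40x40xf32>) -> tensor<40x40xf32> {
  %shape = constant dense<[40, 40]> : tensor<2xi32>
  %cst = constant dense<1.0> : tensor<40xf32>
  %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32>
  %1 = "tfl.reshape"(%0, %shape) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32>
  return %1 : tensor<40x40xf32>
}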
@@ -457,7 +462,10 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op are not
     // changed after the transformation.
-    (IsTailOfShape $rhs, $input)]>;
+    (IsTailOfShape $rhs, $input),
+    (HasRankAtMost<4> $input),
+    (HasRankAtMost<4> $lhs),
+    (HasRankAtMost<4> $rhs)]>;
 }
 
 // Reorder the element-wise value operations and the element move operations,
@@ -568,7 +576,10 @@ foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
     (TFL_AddOp $input,
       (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
       ActFun),
-    [(HasOneUse $first_output)]>;
+    [(HasOneUse $first_output),
+     (HasRankAtMost<4> $input),
+     (HasRankAtMost<4> $a),
+     (HasRankAtMost<4> $b)]>;
 }
 
 // We can eliminate Relu from Relu(SquaredDifference(x, y)),