Do not apply optimizations when they will create 5-D broadcast operators

TFLite broadcastable operations can safely handle broadcast inputs of at most 4 dimensions.

PiperOrigin-RevId: 342803141
Change-Id: Ib7035f5d6a5dbcc918431ae31cfda026b3c3f189
Author: Jaesung Chung (2020-11-16 23:46:32 -08:00), committed by TensorFlower Gardener
parent b437c554e8
commit b858de3779
3 changed files with 104 additions and 7 deletions
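
The guard used throughout this change is simply a rank check on every operand involved in a broadcasting rewrite. Below is a minimal standalone sketch of that idea in plain C++ with hypothetical names (not TensorFlow code); the actual patch expresses the same check as `HasRankAtMost<4>` TableGen constraints and a `getRank() > 4` bail-out in the C++ pattern further down.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical constant: TFLite broadcast kernels implement at most 4-D
// broadcasting, so any rewrite that would require a higher-rank broadcast
// must be skipped.
constexpr std::size_t kMaxTfLiteBroadcastRank = 4;

// Returns true when every operand shape stays within the supported rank,
// i.e. the broadcasting rewrite may proceed; false means "leave the IR as is".
bool BroadcastRewriteIsSafe(
    const std::vector<std::vector<int64_t>>& operand_shapes) {
  for (const std::vector<int64_t>& shape : operand_shapes) {
    if (shape.size() > kMaxTfLiteBroadcastRank) return false;
  }
  return true;
}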

@@ -578,6 +578,32 @@ func @NotReorderReshapeAddIfNotTailingDimAfter(%arg0: tensor<1x30x1x96xf32>) ->
  // CHECK: return %[[rs2]]
}

// CHECK-LABEL: @NotReorderReshapeAddIf5DInputs
func @NotReorderReshapeAddIf5DInputs(%arg0: tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x2xf32> {
  %cst = constant dense<2.0> : tensor<1x1x1x1x2xf32>
  %shape = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
  %1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
  %2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
  return %2 : tensor<1x1x1x1x2xf32>
  // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
  // CHECK: %[[rs2:.*]] = tfl.add %[[rs1]]
  // CHECK: return %[[rs2]]
}

// CHECK-LABEL: @NotReorderReshapeFloorDivIf5DInputs
func @NotReorderReshapeFloorDivIf5DInputs(%arg0: tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x2xf32> {
  %cst = constant dense<2.0> : tensor<1x1x1x1x2xf32>
  %shape = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
  %1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
  %2 = "tfl.floor_div"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
  return %2 : tensor<1x1x1x1x2xf32>
  // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
  // CHECK: %[[rs2:.*]] = tfl.floor_div %[[rs1]]
  // CHECK: return %[[rs2]]
}

// CHECK-LABEL: @NotReorderReshapeAddIfNotTailingDim
func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
  %cst = constant dense<2.0> : tensor<1x40xf32>
@@ -896,6 +922,36 @@ func @fuseTileWithBinaryOp1(%arg0: tensor<1x1xf32>, %arg1: tensor<1x128xf32>) ->
  // CHECK: return %[[RES]]
}

// CHECK-LABEL: notFuseTileWithBinaryOpOn5DInputs
func @notFuseTileWithBinaryOpOn5DInputs(%arg0: tensor<1x1xf32>) -> tensor<1x1x1x1x2xf32> {
  %cst = constant dense<[1, 1, 1, 1, 2]> : tensor<5xi32>
  %cst1 = constant dense<3.0> : tensor<1x1x1x1x2xf32>
  %0 = "tfl.sqrt"(%arg0) : (tensor<1x1xf32>) -> tensor<1x1xf32>
  %1 = "tfl.tile"(%0, %cst) : (tensor<1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x2xf32>
  %2 = "tfl.add"(%cst1, %1) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x2xf32>, tensor<1x1x1x1x2xf32>) -> tensor<1x1x1x1x2xf32>
  return %2 : tensor<1x1x1x1x2xf32>
  // CHECK: "tfl.sqrt"
  // CHECK: "tfl.tile"
  // CHECK: tfl.add
}

// CHECK-LABEL: notFuseTileWithBinaryOp1On5DInputs
func @notFuseTileWithBinaryOp1On5DInputs(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x1x1x128xf32>) -> tensor<1x1x1x1x128xf32> {
  %cst_0 = constant dense<1.0> : tensor<f32>
  %cst_1 = constant dense<[1, 1, 1, 1, 128]> : tensor<5xi32>
  %0 = "tfl.add"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<f32>) -> tensor<1x1xf32>
  %1 = "tfl.sqrt"(%0) : (tensor<1x1xf32>) -> tensor<1x1xf32>
  %2 = "tfl.tile"(%1, %cst_1) : (tensor<1x1xf32>, tensor<5xi32>) -> tensor<1x1x1x1x128xf32>
  %3 = "tfl.div"(%2, %arg1) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x128xf32>, tensor<1x1x1x1x128xf32>) -> tensor<1x1x1x1x128xf32>
  return %3 : tensor<1x1x1x1x128xf32>
  // CHECK: "tfl.add"
  // CHECK: "tfl.sqrt"
  // CHECK: "tfl.tile"
  // CHECK: tfl.div
}

// CHECK-LABEL: InvalidFuseTileWithBinaryOp
func @InvalidFuseTileWithBinaryOp(%arg0: tensor<2x3xf32>) -> tensor<2x6xf32> {
  %cst = constant dense<[[1,2]]> : tensor<1x2xi32>
@@ -1155,6 +1211,18 @@ func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32>
}

func @NotReorderAddWithConstantOn5D(%arg0: tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32> {
  %cst = constant dense<1.0> : tensor<2x2x2x2x2xf32>
  %cst_1 = constant dense<2.0> : tensor<2x2x2x2x2xf32>
  %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x2x2x2x2xf32>, tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32>
  %1 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2x2x2x2xf32>, tensor<2x2x2x2x2xf32>) -> tensor<2x2x2x2x2xf32>
  return %1 : tensor<2x2x2x2x2xf32>
  // CHECK-LABEL: NotReorderAddWithConstantOn5D
  // CHECK: tfl.add
  // CHECK: tfl.add
}

func @RemoveCast(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %1 = "tfl.cast"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %1 : tensor<2x2xf32>
@@ -1427,3 +1495,20 @@ func @fuseMulIntoConv2d_Splat2D(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x
  // CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
  // CHECK: return %[[RES]]
}

// CHECK-LABEL: @AvoidFuseFullyConnectedAddWithSplat2D
func @AvoidFuseFullyConnectedAddWithSplat2D(%arg0: tensor<1x1x1x1x1xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x1x1x1x1xf32> {
  %cst = constant unit
  %cst2 = constant dense<2.0> : tensor<1x1x1x1x1xf32>
  %0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32>
  %1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x1x1x1x1xf32>, tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x1xf32>
  return %1 : tensor<1x1x1x1x1xf32>
  // CHECK: %[[CST1:.*]] = constant unit
  // CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<1x1x1x1x1xf32>
  // CHECK: %[[FC_RESULT:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[CST1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x1x1x1xf32>, tensor<1x1xf32>, none) -> tensor<1x1x1x1x1xf32>
  // CHECK: %[[ADD:.*]] = tfl.add %[[FC_RESULT]], %[[CST2]] {fused_activation_function = "NONE"} : tensor<1x1x1x1x1xf32>
  // CHECK: return %[[ADD]] : tensor<1x1x1x1x1xf32>
}

@@ -778,7 +778,8 @@ struct ScalarizeSplatConstantForBroadcastableOps
    // cannot scalarize the splat constant because the result shape relies on
    // the splat constant op's shape for broadcasting.
    if (!non_splat_operand_type.hasStaticShape() ||
        non_splat_operand_type.getShape() != result_type.getShape()) {
        non_splat_operand_type.getShape() != result_type.getShape() ||
        non_splat_operand_type.getRank() > 4) {
      return failure();
    }
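
A usage sketch of the hypothetical BroadcastRewriteIsSafe helper from the sketch above, mirroring the effect of the new `getRank() > 4` condition: a 4-D non-splat operand still allows the splat constant to be scalarized, while a 5-D operand makes the pattern bail out.

#include <cassert>

int main() {
  // 4-D operand: within TFLite's broadcast support, the rewrite is allowed.
  assert(BroadcastRewriteIsSafe({{2, 2, 2, 2}}));
  // 5-D operand: would require a 5-D broadcast at runtime, so the rewrite is skipped.
  assert(!BroadcastRewriteIsSafe({{2, 2, 2, 2, 2}}));
  return 0;
}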

@@ -376,13 +376,17 @@ multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
    (BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)),
      $operand, $act_func),
    (BinaryOp $input, $operand, $act_func),
    [(OperandsBroadcastToOutputType $input, $operand, $result)]>;
    [(OperandsBroadcastToOutputType $input, $operand, $result),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $operand)]>;

  def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
    (BinaryOp:$result $operand,
      (TFL_TileOp $input, (ConstantOp $tile)), $act_func),
    (BinaryOp $operand, $input, $act_func),
    [(OperandsBroadcastToOutputType $operand, $input, $result)]>;
    [(OperandsBroadcastToOutputType $operand, $input, $result),
     (HasRankAtMost<4> $operand),
     (HasRankAtMost<4> $input)]>;
}
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
@@ -427,8 +431,9 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
     // `input`. In other words, the shape of the `Reshape` op are not
     // changed after the transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<5> $input),
     (HasRankAtMost<5> $rhs)]>;
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs)]>;
}
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
@@ -457,7 +462,10 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op are not
     // changed after the transformation.
     (IsTailOfShape $rhs, $input)]>;
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs)]>;
}
// Reorder the element-wise value operations and the element move operations,
@@ -568,7 +576,10 @@ foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
    (TFL_AddOp $input,
      (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
      ActFun),
    [(HasOneUse $first_output)]>;
    [(HasOneUse $first_output),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $a),
     (HasRankAtMost<4> $b)]>;
}
// We can eliminate Relu from Relu(SquaredDifference(x, y)),