Add pattern to tflite to fuse relu1 with binary ops.
Also, add relu1 as fused activation function in the conv fuse pattern. PiperOrigin-RevId: 284293263 Change-Id: I1edbaf9a1e0e5d60fee3780970c9e5e5092b7b73
This commit is contained in:
parent
30d429ba4a
commit
92f61576fa
@ -622,3 +622,12 @@ func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
|||||||
|
|
||||||
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
|
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: fuse_relu_to_add
|
||||||
|
func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||||
|
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||||
|
%1 = "tfl.relu_n1_to_1"(%0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||||
|
return %1 : tensor<2x3xf32>
|
||||||
|
// CHECK: %[[RES:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"}
|
||||||
|
// CHECK: return %[[RES]]
|
||||||
|
}
|
||||||
|
@ -44,12 +44,13 @@ multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
|
|||||||
$multiplier)>;
|
$multiplier)>;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(hinsu): Also fuse ops corresponding to RELU_N1_TO_1 and SIGN_BIT fused
|
// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
|
||||||
// activation functions.
|
// activation functions.
|
||||||
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
|
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
|
||||||
// those cannot be simply translated into clamping.
|
// those cannot be simply translated into clamping.
|
||||||
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
|
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
|
||||||
[TFL_Relu6Op, TFL_AF_Relu6]] in
|
[TFL_Relu6Op, TFL_AF_Relu6],
|
||||||
|
[TFL_Relu1Op, TFL_AF_Relu1]] in
|
||||||
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
|
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
|
||||||
|
|
||||||
|
|
||||||
@ -291,3 +292,18 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
|
|||||||
(ConstantOp $NegOne)),
|
(ConstantOp $NegOne)),
|
||||||
(TFL_Relu1Op $input),
|
(TFL_Relu1Op $input),
|
||||||
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
|
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
|
||||||
|
|
||||||
|
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
|
||||||
|
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
|
||||||
|
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
|
||||||
|
[TFL_Relu6Op, TFL_AF_Relu6],
|
||||||
|
[TFL_Relu1Op, TFL_AF_Relu1]] in {
|
||||||
|
def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
|
||||||
|
(BinaryOp $lhs, $rhs, actFnPair[1])>;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Instantiated FusedBinary patterns for the from-to pairs of ops.
|
||||||
|
foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
|
||||||
|
TFL_MulOp, TFL_SubOp] in
|
||||||
|
defm : FusedBinaryActivationFuncOpPat<BinaryOps>;
|
||||||
|
Loading…
Reference in New Issue
Block a user