Add pattern to tflite to fuse relu1 with binary ops.
Also, add relu1 as fused activation function in the conv fuse pattern. PiperOrigin-RevId: 284293263 Change-Id: I1edbaf9a1e0e5d60fee3780970c9e5e5092b7b73
This commit is contained in:
parent
30d429ba4a
commit
92f61576fa
@ -622,3 +622,12 @@ func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
|
||||
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fuse_relu_to_add
|
||||
func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%1 = "tfl.relu_n1_to_1"(%0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %1 : tensor<2x3xf32>
|
||||
// CHECK: %[[RES:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"}
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
@ -44,12 +44,13 @@ multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
|
||||
$multiplier)>;
|
||||
}
|
||||
|
||||
// TODO(hinsu): Also fuse ops corresponding to RELU_N1_TO_1 and SIGN_BIT fused
|
||||
// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
|
||||
// activation functions.
|
||||
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
|
||||
// those cannot be simply translated into clamping.
|
||||
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
|
||||
[TFL_Relu6Op, TFL_AF_Relu6]] in
|
||||
[TFL_Relu6Op, TFL_AF_Relu6],
|
||||
[TFL_Relu1Op, TFL_AF_Relu1]] in
|
||||
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
|
||||
|
||||
|
||||
@ -291,3 +292,18 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
|
||||
(ConstantOp $NegOne)),
|
||||
(TFL_Relu1Op $input),
|
||||
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
|
||||
|
||||
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
|
||||
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
|
||||
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
|
||||
[TFL_Relu6Op, TFL_AF_Relu6],
|
||||
[TFL_Relu1Op, TFL_AF_Relu1]] in {
|
||||
def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
|
||||
(BinaryOp $lhs, $rhs, actFnPair[1])>;
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiated FusedBinary patterns for the from-to pairs of ops.
|
||||
foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
|
||||
TFL_MulOp, TFL_SubOp] in
|
||||
defm : FusedBinaryActivationFuncOpPat<BinaryOps>;
|
||||
|
Loading…
Reference in New Issue
Block a user