Add pattern to tflite to fuse relu1 with binary ops.

Also, add relu1 as fused activation function in the conv fuse pattern.

PiperOrigin-RevId: 284293263
Change-Id: I1edbaf9a1e0e5d60fee3780970c9e5e5092b7b73
This commit is contained in:
Karim Nosir 2019-12-06 17:23:10 -08:00 committed by TensorFlower Gardener
parent 30d429ba4a
commit 92f61576fa
2 changed files with 27 additions and 2 deletions

View File

@ -622,3 +622,12 @@ func @Relu1_2(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[relu_n1_to_1:[0-9].*]] = "tfl.relu_n1_to_1"
}
// CHECK-LABEL: fuse_relu_to_add
func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3xf32> {
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%1 = "tfl.relu_n1_to_1"(%0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
// CHECK: %[[RES:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"}
// CHECK: return %[[RES]]
}

View File

@ -44,12 +44,13 @@ multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
$multiplier)>;
}
// TODO(hinsu): Also fuse ops corresponding to RELU_N1_TO_1 and SIGN_BIT fused
// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
// activation functions.
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
// those cannot be simply translated into clamping.
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu6Op, TFL_AF_Relu6]] in
[TFL_Relu6Op, TFL_AF_Relu6],
[TFL_Relu1Op, TFL_AF_Relu1]] in
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
@ -291,3 +292,18 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
(ConstantOp $NegOne)),
(TFL_Relu1Op $input),
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu6Op, TFL_AF_Relu6],
[TFL_Relu1Op, TFL_AF_Relu1]] in {
def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
(BinaryOp $lhs, $rhs, actFnPair[1])>;
}
}
// Instantiated FusedBinary patterns for the from-to pairs of ops.
foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
TFL_MulOp, TFL_SubOp] in
defm : FusedBinaryActivationFuncOpPat<BinaryOps>;