Adds pattern for LeakyRelu in TFLite
PiperOrigin-RevId: 291406515
Change-Id: I2b155c82109e0e3db2f9f4c09ec8357e6777b972
This commit is contained in:
parent 2673c7c0a4
commit 0da59c63fa
@@ -690,6 +690,27 @@ func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor
// CHECK: return %[[RES]]
}

// CHECK-LABEL: leaky_relu_fusion
func @leaky_relu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %alpha = constant dense<0.2> : tensor<f32>
  %0 = "tfl.mul"(%arg0, %alpha) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
  %1 = "tfl.maximum"(%0, %arg0) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
  return %1 : tensor<2x3xf32>

  // CHECK: %[[RESULT:[0-9].*]] = "tfl.leaky_relu"
}

// CHECK-LABEL: leaky_relu_not_fused
// Should not fuse to LeakyRelu, since alpha > 1.
func @leaky_relu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %alpha = constant dense<1.2> : tensor<f32>
  %0 = "tfl.mul"(%arg0, %alpha) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
  %1 = "tfl.maximum"(%0, %arg0) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
  return %1 : tensor<2x3xf32>

  // CHECK: %[[RESULT:[0-9].*]] = "tfl.maximum"
}

// CHECK-LABEL: NotfuseAddIntoConv2d_MultipleUsers
func @NotfuseAddIntoConv2d_MultipleUsers(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
  %cst = constant dense<1.5> : tensor<16xf32>

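The two tests above exercise the identity the new pattern relies on: max(x, alpha * x) equals LeakyRelu with slope alpha only when alpha lies in [0, 1]; for alpha > 1 the maximum picks the other branch on negative inputs, so the second function must stay unfused. A minimal NumPy sketch of that check (illustrative only, not part of this commit):

# Illustrative sketch only (assumes NumPy; not part of this commit).
import numpy as np

def leaky_relu(x, alpha):
    # Reference definition: x for x >= 0, alpha * x otherwise.
    return np.where(x >= 0, x, alpha * x)

def mul_then_maximum(x, alpha):
    # Shape of the matched subgraph: tfl.maximum(tfl.mul(x, alpha), x).
    return np.maximum(x * alpha, x)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)

# alpha = 0.2 (< 1): both expressions agree, so the rewrite is sound.
assert np.allclose(mul_then_maximum(x, 0.2), leaky_relu(x, 0.2))

# alpha = 1.2 (> 1): they diverge on negative inputs, so no fusion is emitted.
assert not np.allclose(mul_then_maximum(x, 1.2), leaky_relu(x, 1.2))
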
@@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

@@ -23,6 +23,9 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def F32ElementsAttr : ElementsAttrBase<
  CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;

def ExtractSingleElementAsFloat : NativeCodeCall<
    "ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;

//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//

@@ -358,6 +361,7 @@ class ValueEquals<string val> : Constraint<CPred<
  "$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
  "*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;

// ReLU patterns
def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input,
                          (ConstantOp $NegOne)),
          (ConstantOp $One)),

@@ -370,6 +374,14 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
          (TFL_Relu1Op $input),
          [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;

def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1,
                          (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
                         $input2),
          (TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha),
          [(ConstDoubleValueLessThan<"1"> $alpha),
           (EqualOperands $input1, $input2),
           (HasOneUse $mul_out)]>;

// The constant folding in this pass might produce constant in the tf dialect.
// This rule is to legalize these constant to the tfl dialect.
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

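For readers who do not use DRR day to day, the LeakyRelu pattern above bundles four conditions: the multiply feeding the maximum has no fused activation, its constant operand is a single float below 1, both ops see the same input value, and the multiply result has no other users. A rough, self-contained Python sketch of that matching logic (the Value/Op classes here are hypothetical stand-ins, not MLIR's API):

# Illustrative sketch only: a toy IR and the checks the DRR pattern encodes.
# These Value/Op classes are hypothetical stand-ins, not MLIR's C++ or Python API.
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Value:
    defining_op: Optional["Op"] = None
    uses: list = field(default_factory=list)

@dataclass
class Op:
    name: str
    operands: list
    attrs: dict = field(default_factory=dict)
    result: Value = field(default_factory=Value)

    def __post_init__(self):
        self.result.defining_op = self
        for operand in self.operands:
            operand.uses.append(self)

def try_fuse_leaky_relu(maximum_op: Op) -> Optional[Op]:
    """Rewrite tfl.maximum(tfl.mul(x, alpha), x) into tfl.leaky_relu(x) {alpha}."""
    if maximum_op.name != "tfl.maximum":
        return None
    mul_out, input2 = maximum_op.operands
    mul_op = mul_out.defining_op
    if mul_op is None or mul_op.name != "tfl.mul":
        return None                                  # LHS must be a tfl.mul result.
    if mul_op.attrs.get("fused_activation_function") != "NONE":
        return None                                  # TFL_AF_None on the multiply.
    input1, alpha_value = mul_op.operands
    alpha_op = alpha_value.defining_op
    alpha = alpha_op.attrs.get("value") if alpha_op and alpha_op.name == "constant" else None
    if not isinstance(alpha, float) or alpha >= 1.0:
        return None                                  # ConstDoubleValueLessThan<"1">.
    if input1 is not input2:
        return None                                  # EqualOperands $input1, $input2.
    if len(mul_out.uses) != 1:
        return None                                  # HasOneUse $mul_out.
    return Op("tfl.leaky_relu", [input1], {"alpha": alpha})

# Mirrors the leaky_relu_fusion test: alpha = 0.2 fuses; alpha = 1.2 would not.
x = Value()
alpha_const = Op("constant", [], {"value": 0.2})
mul = Op("tfl.mul", [x, alpha_const.result], {"fused_activation_function": "NONE"})
maximum = Op("tfl.maximum", [mul.result, x])
assert try_fuse_leaky_relu(maximum) is not None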