Optimize tfl.maximum($x, 0) to tfl.relu($x) to enable further optimizations
PiperOrigin-RevId: 324166099 Change-Id: Ic3308361eb916fd0eccd928392000609c7b1c6bc
This commit is contained in:
parent
6be7a14b65
commit
98400b759b
@ -855,6 +855,15 @@ func @doNotConvertNonTrivialTransposeToReshape(%arg0: tensor<6x6x256x1xf32>) ->
|
||||
// CHECK: return %[[RESULT]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Relu
|
||||
func @Relu(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%cst = constant dense<0.0> : tensor<f32>
|
||||
%0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
|
||||
return %0 : tensor<2x3xf32>
|
||||
|
||||
// CHECK: %[[RESULT:.*]] = "tfl.relu"(%arg0)
|
||||
// CHECK: return %[[RESULT]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Relu1
|
||||
func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
|
@ -425,6 +425,11 @@ class FloatValueEquals<string val> : Constraint<CPred<
|
||||
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
|
||||
|
||||
// ReLU patterns
|
||||
def MatchReluPattern : Pat<
|
||||
(TFL_MaximumOp $input, (ConstantOp $Zero)),
|
||||
(TFL_ReluOp $input),
|
||||
[(FloatValueEquals<"0"> $Zero)]>;
|
||||
|
||||
def MatchRelu1Pattern1 : Pat<
|
||||
(TFL_MinimumOp (TFL_MaximumOp $input, (ConstantOp $NegOne)),
|
||||
(ConstantOp $One)),
|
||||
|
Loading…
x
Reference in New Issue
Block a user