Optimize tfl.maximum($x, 0) to tfl.relu($x) to enable further optimizations

PiperOrigin-RevId: 324166099
Change-Id: Ic3308361eb916fd0eccd928392000609c7b1c6bc
This commit is contained in:
A. Unique TensorFlower 2020-07-31 00:48:13 -07:00 committed by TensorFlower Gardener
parent 6be7a14b65
commit 98400b759b
2 changed files with 14 additions and 0 deletions

View File

@ -855,6 +855,15 @@ func @doNotConvertNonTrivialTransposeToReshape(%arg0: tensor<6x6x256x1xf32>) ->
// CHECK: return %[[RESULT]]
}
// CHECK-LABEL: Relu
func @Relu(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%cst = constant dense<0.0> : tensor<f32>
%0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
return %0 : tensor<2x3xf32>
// CHECK: %[[RESULT:.*]] = "tfl.relu"(%arg0)
// CHECK: return %[[RESULT]]
}
// CHECK-LABEL: Relu1
func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {

View File

@ -425,6 +425,11 @@ class FloatValueEquals<string val> : Constraint<CPred<
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
// ReLU patterns
def MatchReluPattern : Pat<
(TFL_MaximumOp $input, (ConstantOp $Zero)),
(TFL_ReluOp $input),
[(FloatValueEquals<"0"> $Zero)]>;
def MatchRelu1Pattern1 : Pat<
(TFL_MinimumOp (TFL_MaximumOp $input, (ConstantOp $NegOne)),
(ConstantOp $One)),