[MLIR:TF] Replace RealDiv with constant divisor with multiplication by divisor reciprocal
PiperOrigin-RevId: 321301105 Change-Id: If0b55598897dbc4a26bba5eb8af552199f3f28ae
This commit is contained in:
parent
6e7992cae2
commit
0848e16fb1
@ -2774,7 +2774,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<RealDivWithSqrtDivisor>(context);
|
||||
results.insert<RealDivWithSqrtDivisor, RealDivWithConstDivisor>(context);
|
||||
}
|
||||
|
||||
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
@ -586,6 +586,18 @@ func @testRealDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32
|
||||
// CHECK: return %1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testRealDivWithConstDivisor
|
||||
func @testRealDivWithConstDivisor(%arg0: tensor<8x2xf32>) -> tensor<8x2xf32> {
|
||||
%0 = "tf.Const"() {value = dense<[2.0, 4.0]> : tensor<2xf32>} : () -> tensor<2xf32>
|
||||
%1 = "tf.RealDiv"(%arg0, %0) : (tensor<8x2xf32>, tensor<2xf32>) -> tensor<8x2xf32>
|
||||
return %1: tensor<8x2xf32>
|
||||
|
||||
// CHECK: %0 = "tf.Const"
|
||||
// CHECK-SAME: value = dense<[5.000000e-01, 2.500000e-01]
|
||||
// CHECK: %1 = "tf.Mul"(%arg0, %0)
|
||||
// CHECK: return %1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testTruncateDivWithSqrtDivisor
|
||||
func @testTruncateDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Sqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
|
@ -150,6 +150,7 @@ def LogToLog1p : Pat<
|
||||
// LogicalNot op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(ezhulenev): Generalize this pattern for all involutions.
|
||||
def LogicalNotNested : Pat<(TF_LogicalNotOp (TF_LogicalNotOp $arg)),
|
||||
(replaceWithValue $arg)>;
|
||||
|
||||
@ -187,6 +188,13 @@ def NegNested : Pat<(TF_NegOp (TF_NegOp $arg)), (replaceWithValue $arg)>;
|
||||
def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)),
|
||||
(TF_MulOp $arg0, (TF_RsqrtOp $arg1))>;
|
||||
|
||||
// Replace division by a constant with a multiplication by a reciprocal of that
|
||||
// constant. Floating point division can be ~10x more expensive than a
|
||||
// multiplication.
|
||||
def RealDivWithConstDivisor : Pat<
|
||||
(TF_RealDivOp $arg0, (TF_ConstOp FloatElementsAttr<32>:$value)),
|
||||
(TF_MulOp $arg0, (TF_ReciprocalOp (TF_ConstOp $value)))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Reciprocal op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Loading…
Reference in New Issue
Block a user