[MLIR:TF] Replace RealDiv with constant divisor with multiplication by divisor reciprocal

PiperOrigin-RevId: 321301105
Change-Id: If0b55598897dbc4a26bba5eb8af552199f3f28ae
This commit is contained in:
Eugene Zhulenev 2020-07-14 22:40:02 -07:00 committed by TensorFlower Gardener
parent 6e7992cae2
commit 0848e16fb1
3 changed files with 21 additions and 1 deletions

View File

@ -2774,7 +2774,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<RealDivWithSqrtDivisor>(context);
results.insert<RealDivWithSqrtDivisor, RealDivWithConstDivisor>(context);
}
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {

View File

@ -586,6 +586,18 @@ func @testRealDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32
// CHECK: return %1
}
// CHECK-LABEL: testRealDivWithConstDivisor
func @testRealDivWithConstDivisor(%arg0: tensor<8x2xf32>) -> tensor<8x2xf32> {
%0 = "tf.Const"() {value = dense<[2.0, 4.0]> : tensor<2xf32>} : () -> tensor<2xf32>
%1 = "tf.RealDiv"(%arg0, %0) : (tensor<8x2xf32>, tensor<2xf32>) -> tensor<8x2xf32>
return %1: tensor<8x2xf32>
// CHECK: %0 = "tf.Const"
// CHECK-SAME: value = dense<[5.000000e-01, 2.500000e-01]
// CHECK: %1 = "tf.Mul"(%arg0, %0)
// CHECK: return %1
}
// CHECK-LABEL: testTruncateDivWithSqrtDivisor
func @testTruncateDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Sqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32>

View File

@ -150,6 +150,7 @@ def LogToLog1p : Pat<
// LogicalNot op patterns.
//===----------------------------------------------------------------------===//
// TODO(ezhulenev): Generalize this pattern for all involutions.
def LogicalNotNested : Pat<(TF_LogicalNotOp (TF_LogicalNotOp $arg)),
(replaceWithValue $arg)>;
@ -187,6 +188,13 @@ def NegNested : Pat<(TF_NegOp (TF_NegOp $arg)), (replaceWithValue $arg)>;
def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)),
(TF_MulOp $arg0, (TF_RsqrtOp $arg1))>;
// Replace division by a constant with a multiplication by a reciprocal of that
// constant. Floating point division can be ~10x more expensive than a
// multiplication.
def RealDivWithConstDivisor : Pat<
(TF_RealDivOp $arg0, (TF_ConstOp FloatElementsAttr<32>:$value)),
(TF_MulOp $arg0, (TF_ReciprocalOp (TF_ConstOp $value)))>;
//===----------------------------------------------------------------------===//
// Reciprocal op patterns.
//===----------------------------------------------------------------------===//