diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index cfd3b61b2c3..98bc6b3089a 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -2774,7 +2774,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
 
 void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  results.insert<RealDivWithSqrtDivisor>(context);
+  results.insert<RealDivWithSqrtDivisor, RealDivWithConstDivisor>(context);
 }
 
 OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index d61fc66a5e6..514db1f4f08 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -586,6 +586,18 @@ func @testRealDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32
   // CHECK: return %1
 }
 
+// CHECK-LABEL: testRealDivWithConstDivisor
+func @testRealDivWithConstDivisor(%arg0: tensor<8x2xf32>) -> tensor<8x2xf32> {
+  %0 = "tf.Const"() {value = dense<[2.0, 4.0]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %1 = "tf.RealDiv"(%arg0, %0) : (tensor<8x2xf32>, tensor<2xf32>) -> tensor<8x2xf32>
+  return %1: tensor<8x2xf32>
+
+  // CHECK: %0 = "tf.Const"
+  // CHECK-SAME: value = dense<[5.000000e-01, 2.500000e-01]
+  // CHECK: %1 = "tf.Mul"(%arg0, %0)
+  // CHECK: return %1
+}
+
 // CHECK-LABEL: testTruncateDivWithSqrtDivisor
 func @testTruncateDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
   %0 = "tf.Sqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
index 9d72284da91..3f0b5b48af9 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
@@ -150,6 +150,7 
@@ def LogToLog1p : Pat<
 // LogicalNot op patterns.
 //===----------------------------------------------------------------------===//
 
+// TODO(ezhulenev): Generalize this pattern for all involutions.
 def LogicalNotNested : Pat<(TF_LogicalNotOp (TF_LogicalNotOp $arg)),
                            (replaceWithValue $arg)>;
 
@@ -187,6 +188,13 @@ def NegNested : Pat<(TF_NegOp (TF_NegOp $arg)), (replaceWithValue $arg)>;
 def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)),
                                  (TF_MulOp $arg0, (TF_RsqrtOp $arg1))>;
 
+// Rewrite a division by a 32-bit float constant as a multiplication by the
+// reciprocal of that constant, because floating point division can be ~10x
+// more expensive than a multiplication.
+def RealDivWithConstDivisor : Pat<
+  (TF_RealDivOp $arg0, (TF_ConstOp FloatElementsAttr<32>:$value)),
+  (TF_MulOp $arg0, (TF_ReciprocalOp (TF_ConstOp $value)))>;
+
 //===----------------------------------------------------------------------===//
 // Reciprocal op patterns.
 //===----------------------------------------------------------------------===//