diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 82282bb925a..d53bafff638 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -6331,6 +6331,8 @@ If `x` and `y` are reals, this will return the floating-point division. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasCanonicalizer = 1; + + let hasFolder = 1; } def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 2007824369c..78623ca3c61 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -110,7 +110,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) { return !type || type.getRank() <= rank; } - static bool IsUnknownDimOrRank(int64_t dim_or_rank) { return dim_or_rank == -1; } @@ -462,9 +461,10 @@ LogicalResult FoldOperandsPermutation( namespace { // Folder that returns LHS of an Arithmetic Op if the RHS is a constant // known to be Identity (e.g X+0) -template ::value>::type * = nullptr> +template < + typename OpT, + typename std::enable_if::value>::type * = nullptr> OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, ArrayRef operands) { auto result_op_type = arithmetic_op.getResult().getType(); @@ -479,7 +479,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, // Mul and Div ops have identity value one while AddV2 and SubOp have identity // value zero. int identity = - (std::is_same::value || std::is_same::value); + (std::is_same::value || std::is_same::value || + std::is_same::value); Type element_ty = lhs_type.getElementType(); Attribute identity_attr; @@ -2408,6 +2409,10 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +OpFoldResult RealDivOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index bccb8923134..32815956ff7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -384,6 +384,15 @@ func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> } +func @RemoveTrivialRealDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %0 = "tf.RealDiv"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialRealDiv + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> +} + func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { %cst = constant dense<1.0> : tensor<2x2xbf16> %0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>