Optimize trivial RealDiv ops

PiperOrigin-RevId: 311874492
Change-Id: I8084b4a0a913d4585420bff20a21688ae8d41286
This commit is contained in:
Karim Nosir 2020-05-16 03:51:08 -07:00 committed by TensorFlower Gardener
parent fd976b2def
commit d70dc548b5
3 changed files with 21 additions and 5 deletions

View File

@ -6331,6 +6331,8 @@ If `x` and `y` are reals, this will return the floating-point division.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> {

View File

@ -110,7 +110,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) {
return !type || type.getRank() <= rank;
}
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
return dim_or_rank == -1;
}
@ -462,9 +461,10 @@ LogicalResult FoldOperandsPermutation(
namespace {
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
// known to be Identity (e.g X+0)
template <typename OpT,
typename std::enable_if<llvm::is_one_of<
OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr>
template <
typename OpT,
typename std::enable_if<llvm::is_one_of<
OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
ArrayRef<Attribute> operands) {
auto result_op_type = arithmetic_op.getResult().getType();
@ -479,7 +479,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
// value zero.
int identity =
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value);
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
std::is_same<OpT, RealDivOp>::value);
Type element_ty = lhs_type.getElementType();
Attribute identity_attr;
@ -2408,6 +2409,10 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<RealDivWithSqrtDivisor>(context);
}
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

View File

@ -384,6 +384,15 @@ func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}
func @RemoveTrivialRealDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<1.0> : tensor<2x2xf32>
%0 = "tf.RealDiv"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialRealDiv
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
}
func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
%cst = constant dense<1.0> : tensor<2x2xbf16>
%0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>