Optimize trivial RealDiv ops
PiperOrigin-RevId: 311874492 Change-Id: I8084b4a0a913d4585420bff20a21688ae8d41286
This commit is contained in:
parent
fd976b2def
commit
d70dc548b5
@ -6331,6 +6331,8 @@ If `x` and `y` are reals, this will return the floating-point division.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
|
@ -110,7 +110,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) {
|
||||
return !type || type.getRank() <= rank;
|
||||
}
|
||||
|
||||
|
||||
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
|
||||
return dim_or_rank == -1;
|
||||
}
|
||||
@ -462,9 +461,10 @@ LogicalResult FoldOperandsPermutation(
|
||||
namespace {
|
||||
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
|
||||
// known to be Identity (e.g X+0)
|
||||
template <typename OpT,
|
||||
typename std::enable_if<llvm::is_one_of<
|
||||
OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr>
|
||||
template <
|
||||
typename OpT,
|
||||
typename std::enable_if<llvm::is_one_of<
|
||||
OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
|
||||
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
ArrayRef<Attribute> operands) {
|
||||
auto result_op_type = arithmetic_op.getResult().getType();
|
||||
@ -479,7 +479,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
|
||||
// value zero.
|
||||
int identity =
|
||||
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value);
|
||||
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
|
||||
std::is_same<OpT, RealDivOp>::value);
|
||||
|
||||
Type element_ty = lhs_type.getElementType();
|
||||
Attribute identity_attr;
|
||||
@ -2408,6 +2409,10 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
results.insert<RealDivWithSqrtDivisor>(context);
|
||||
}
|
||||
|
||||
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
|
||||
return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -384,6 +384,15 @@ func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor
|
||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialRealDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%cst = constant dense<1.0> : tensor<2x2xf32>
|
||||
%0 = "tf.RealDiv"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialRealDiv
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||
%cst = constant dense<1.0> : tensor<2x2xbf16>
|
||||
%0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
|
||||
|
Loading…
Reference in New Issue
Block a user