Add folder to mhlo::round_nearest_afz
PiperOrigin-RevId: 337823786 Change-Id: Ibd23058c6a2287f1de859fcfa00eb63ed7bfc3f4
This commit is contained in:
parent
a477bd308c
commit
2ff96dca28
@ -241,7 +241,9 @@ def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
|
||||
}
|
||||
|
||||
def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz",
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp;
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt",
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
|
||||
|
||||
@ -1933,6 +1933,14 @@ static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
|
||||
return DenseElementsAttr::get(type, values);
|
||||
}
|
||||
|
||||
struct round {
|
||||
APFloat operator()(const APFloat& f) {
|
||||
APFloat r = f;
|
||||
r.roundToIntegral(llvm::RoundingMode::NearestTiesToAway);
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
#define UNARY_FOLDER(Op, Func) \
|
||||
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
|
||||
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
|
||||
@ -1942,7 +1950,15 @@ static Attribute UnaryFolder(Op* op, ArrayRef<Attribute> attrs) {
|
||||
return {}; \
|
||||
}
|
||||
|
||||
#define UNARY_FOLDER_FLOAT(Op, Func) \
|
||||
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
|
||||
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
|
||||
return UnaryFolder<Op, FloatType, APFloat, Func>(this, attrs); \
|
||||
return {}; \
|
||||
}
|
||||
|
||||
UNARY_FOLDER(NegOp, std::negate);
|
||||
UNARY_FOLDER_FLOAT(RoundOp, round);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BinaryOps
|
||||
|
||||
@ -81,6 +81,14 @@ func @remainder_fold_float() -> tensor<4xf32> {
|
||||
return %2 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: round_fold
|
||||
func @round_fold() -> tensor<4xf32> {
|
||||
%0 = mhlo.constant dense<[-1.5, -0.1, 1.1, 2.5]> : tensor<4xf32>
|
||||
%1 = "mhlo.round_nearest_afz"(%0) : (tensor<4xf32>) -> tensor<4xf32>
|
||||
return %1 : tensor<4xf32>
|
||||
// CHECK: mhlo.constant dense<[-2.000000e+00, -0.000000e+00, 1.000000e+00, 3.000000e+00]>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: max_scalar_fold
|
||||
func @max_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user