Add division, min and max constant folding
Note: For integers, these operations are performed in their signed versions. PiperOrigin-RevId: 313918602 Change-Id: Icc1b68219351985f537ad69b45f7328f65d5d223
This commit is contained in:
parent
64376b5432
commit
e9077ee047
@ -1474,6 +1474,38 @@ static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
|
||||
return DenseElementsAttr::get(type, values);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct divide : std::divides<T> {};
|
||||
|
||||
template <>
|
||||
struct divide<APInt> {
|
||||
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct max {
|
||||
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct max<APInt> {
|
||||
APInt operator()(const APInt& a, const APInt& b) const {
|
||||
return llvm::APIntOps::smax(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct min {
|
||||
T operator()(const T& a, const T& b) const { return std::min<T>(a, b); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct min<APInt> {
|
||||
APInt operator()(const APInt& a, const APInt& b) const {
|
||||
return llvm::APIntOps::smin(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
#define BINARY_FOLDER(Op, Func) \
|
||||
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
|
||||
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
|
||||
@ -1483,9 +1515,16 @@ static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
|
||||
return {}; \
|
||||
}
|
||||
|
||||
// Addition, subtraction and multiplication use the std:: versions of the ops.
|
||||
// Due to the other ops behaving differently in signed vs unsigned integers,
|
||||
// APInts need a special implementation. Currently, it replicates signed int
|
||||
// op behavior.
|
||||
BINARY_FOLDER(AddOp, std::plus);
|
||||
BINARY_FOLDER(SubOp, std::minus);
|
||||
BINARY_FOLDER(MulOp, std::multiplies);
|
||||
BINARY_FOLDER(DivOp, divide);
|
||||
BINARY_FOLDER(MaxOp, max);
|
||||
BINARY_FOLDER(MinOp, min);
|
||||
|
||||
#undef BINARY_FOLDER
|
||||
|
||||
|
@ -303,14 +303,17 @@ def HLO_ComplexOp: HLO_Op<"complex",
|
||||
|
||||
def HLO_DivOp : HLO_BinaryElementwiseOp<"divide",
|
||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_DivOp {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum",
|
||||
[Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MaxOp {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum",
|
||||
[Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MinOp {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply",
|
||||
|
@ -45,6 +45,60 @@ func @multiply_scalar_fold() -> tensor<4xi64> {
|
||||
return %2 : tensor<4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: divide_scalar_fold
|
||||
func @divide_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
|
||||
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<1>
|
||||
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
return %2 : tensor<4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: divide_fold_float
|
||||
func @divide_fold_float() -> tensor<4xf64> {
|
||||
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
|
||||
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
|
||||
// CHECK: xla_hlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
|
||||
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
|
||||
return %2 : tensor<4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: max_scalar_fold
|
||||
func @max_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
|
||||
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<7>
|
||||
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
return %2 : tensor<4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: max_fold_float
|
||||
func @max_fold_float() -> tensor<4xf64> {
|
||||
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
|
||||
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
|
||||
// CHECK: xla_hlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
|
||||
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
|
||||
return %2 : tensor<4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: min_scalar_fold
|
||||
func @min_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
|
||||
%1 = xla_hlo.constant dense<-5> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<-5>
|
||||
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
|
||||
return %2 : tensor<4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: min_fold_float
|
||||
func @min_fold_float() -> tensor<4xf64> {
|
||||
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
|
||||
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
|
||||
// CHECK: xla_hlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
|
||||
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
|
||||
return %2 : tensor<4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_noop
|
||||
func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user