Add division, min and max constant folding

Note: For integers, these operations are performed in their signed versions.
PiperOrigin-RevId: 313918602
Change-Id: Icc1b68219351985f537ad69b45f7328f65d5d223
This commit is contained in:
Natasha Kononenko 2020-05-30 02:28:57 -07:00 committed by TensorFlower Gardener
parent 64376b5432
commit e9077ee047
3 changed files with 96 additions and 0 deletions

View File

@ -1474,6 +1474,38 @@ static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
return DenseElementsAttr::get(type, values);
}
template <typename T>
struct divide : std::divides<T> {};
template <>
struct divide<APInt> {
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
};
template <typename T>
struct max {
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
};
template <>
struct max<APInt> {
APInt operator()(const APInt& a, const APInt& b) const {
return llvm::APIntOps::smax(a, b);
}
};
template <typename T>
struct min {
T operator()(const T& a, const T& b) const { return std::min<T>(a, b); }
};
template <>
struct min<APInt> {
APInt operator()(const APInt& a, const APInt& b) const {
return llvm::APIntOps::smin(a, b);
}
};
#define BINARY_FOLDER(Op, Func) \
OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
if (getElementTypeOrSelf(getType()).isa<FloatType>()) \
@ -1483,9 +1515,16 @@ static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) {
return {}; \
}
// Addition, subtraction and multiplication use the std:: versions of the ops.
// Due to the other ops behaving differently in signed vs unsigned integers,
// APInts need a special implementation. Currently, it replicates signed int
// op behavior.
BINARY_FOLDER(AddOp, std::plus);
BINARY_FOLDER(SubOp, std::minus);
BINARY_FOLDER(MulOp, std::multiplies);
BINARY_FOLDER(DivOp, divide);
BINARY_FOLDER(MaxOp, max);
BINARY_FOLDER(MinOp, min);
#undef BINARY_FOLDER

View File

@ -303,14 +303,17 @@ def HLO_ComplexOp: HLO_Op<"complex",
def HLO_DivOp : HLO_BinaryElementwiseOp<"divide",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_DivOp {
let hasFolder = 1;
}
def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum",
[Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MaxOp {
let hasFolder = 1;
}
def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum",
[Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MinOp {
let hasFolder = 1;
}
def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply",

View File

@ -45,6 +45,60 @@ func @multiply_scalar_fold() -> tensor<4xi64> {
return %2 : tensor<4xi64>
}
// CHECK-LABEL: divide_scalar_fold
func @divide_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<1>
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: divide_fold_float
func @divide_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<7>
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: max_fold_float
func @max_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: min_scalar_fold
func @min_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<-5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<-5>
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: min_fold_float
func @min_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: concatenate_noop
func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>