[XLA] Implement exponentiation by squaring for integers
PiperOrigin-RevId: 339569265 Change-Id: I3fd1f9ba2659eefd3edec0cc7d5359c34b60fc04
This commit is contained in:
parent
612a5fb91e
commit
20e248e158
@ -1578,6 +1578,33 @@ llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
|
||||
Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
|
||||
}
|
||||
|
||||
llvm::Value* ElementalIrEmitter::EmitIntegerPow(llvm::Value* base,
|
||||
llvm::Value* exponent,
|
||||
bool is_signed) {
|
||||
// Exponentiation by squaring:
|
||||
// https://en.wikipedia.org/wiki/Exponentiation_by_squaring;
|
||||
int bits = 6; // Everything else would overflow for any exponent > 1, as 2^64
|
||||
// is the larget possible exponent for a 64-bit integer, and
|
||||
// that's 1 << 6.
|
||||
llvm::Value* accumulator = llvm::ConstantInt::get(base->getType(), 1);
|
||||
llvm::Value* one = llvm::ConstantInt::get(exponent->getType(), 1);
|
||||
llvm::Value* zero = llvm::ConstantInt::get(exponent->getType(), 0);
|
||||
llvm::Value* original_base = base;
|
||||
llvm::Value* original_exponent = exponent;
|
||||
|
||||
// Unroll the loop at compile time.
|
||||
for (int i = 0; i < bits; i++) {
|
||||
accumulator =
|
||||
b_->CreateSelect(b_->CreateICmpEQ(b_->CreateAnd(exponent, one), one),
|
||||
b_->CreateMul(accumulator, base), accumulator);
|
||||
base = b_->CreateMul(base, base);
|
||||
exponent = b_->CreateLShr(exponent, 1);
|
||||
}
|
||||
return b_->CreateSelect(
|
||||
b_->CreateICmpSGE(original_exponent, zero), accumulator,
|
||||
b_->CreateSelect(b_->CreateICmpEQ(original_base, one), one, zero));
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
|
||||
const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
|
||||
bool is_signed) {
|
||||
@ -1627,6 +1654,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
|
||||
return And(lhs_value, rhs_value);
|
||||
case HloOpcode::kOr:
|
||||
return Or(lhs_value, rhs_value);
|
||||
case HloOpcode::kPower:
|
||||
return EmitIntegerPow(lhs_value, rhs_value, is_signed);
|
||||
case HloOpcode::kXor:
|
||||
return Xor(lhs_value, rhs_value);
|
||||
|
||||
|
@ -92,6 +92,8 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
bool is_signed);
|
||||
llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs,
|
||||
bool is_signed);
|
||||
llvm::Value* EmitIntegerPow(llvm::Value* lhs, llvm::Value* rhs,
|
||||
bool is_signed);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
|
||||
llvm::Value* lhs_value,
|
||||
|
@ -146,6 +146,30 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
|
||||
ComputeAndCompareR1<bool>(&builder, {}, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, IntPow) {
|
||||
XlaBuilder builder(TestName());
|
||||
XlaOp lhs =
|
||||
ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, -1, -2, 3, 5, 3, 1});
|
||||
XlaOp rhs =
|
||||
ConstantR1<int32>(&builder, {0, 3, 3, 3, 3, 3, 2, 3, 2, 10, -100, -2});
|
||||
Pow(lhs, rhs);
|
||||
|
||||
std::vector<int32> expected = {1, 1, 8, 27, 64, 125, 1, -8, 9, 9765625, 0, 1};
|
||||
|
||||
ComputeAndCompareR1<int32>(&builder, expected, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, IntPowLarge) {
|
||||
XlaBuilder builder(TestName());
|
||||
XlaOp lhs = ConstantR1<int64>(&builder, {2});
|
||||
XlaOp rhs = ConstantR1<int64>(&builder, {62});
|
||||
Pow(lhs, rhs);
|
||||
|
||||
std::vector<int64> expected = {4611686018427387904};
|
||||
|
||||
ComputeAndCompareR1<int64>(&builder, expected, {});
|
||||
}
|
||||
|
||||
// A non-canonical quiet NaN value.
|
||||
static const float kNonCanonicalNaN = absl::bit_cast<float>(0x7FD01234);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user