[XLA] Implement exponentiation by squaring for integers

PiperOrigin-RevId: 339569265
Change-Id: I3fd1f9ba2659eefd3edec0cc7d5359c34b60fc04
This commit is contained in:
George Karpenkov 2020-10-28 17:15:09 -07:00 committed by TensorFlower Gardener
parent 612a5fb91e
commit 20e248e158
3 changed files with 55 additions and 0 deletions

View File

@ -1578,6 +1578,33 @@ llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
}
llvm::Value* ElementalIrEmitter::EmitIntegerPow(llvm::Value* base,
llvm::Value* exponent,
bool is_signed) {
// Exponentiation by squaring:
// https://en.wikipedia.org/wiki/Exponentiation_by_squaring;
int bits = 6; // Everything else would overflow for any exponent > 1, as 2^64
// is the larget possible exponent for a 64-bit integer, and
// that's 1 << 6.
llvm::Value* accumulator = llvm::ConstantInt::get(base->getType(), 1);
llvm::Value* one = llvm::ConstantInt::get(exponent->getType(), 1);
llvm::Value* zero = llvm::ConstantInt::get(exponent->getType(), 0);
llvm::Value* original_base = base;
llvm::Value* original_exponent = exponent;
// Unroll the loop at compile time.
for (int i = 0; i < bits; i++) {
accumulator =
b_->CreateSelect(b_->CreateICmpEQ(b_->CreateAnd(exponent, one), one),
b_->CreateMul(accumulator, base), accumulator);
base = b_->CreateMul(base, base);
exponent = b_->CreateLShr(exponent, 1);
}
return b_->CreateSelect(
b_->CreateICmpSGE(original_exponent, zero), accumulator,
b_->CreateSelect(b_->CreateICmpEQ(original_base, one), one, zero));
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed) {
@ -1627,6 +1654,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
return And(lhs_value, rhs_value);
case HloOpcode::kOr:
return Or(lhs_value, rhs_value);
case HloOpcode::kPower:
return EmitIntegerPow(lhs_value, rhs_value, is_signed);
case HloOpcode::kXor:
return Xor(lhs_value, rhs_value);

View File

@ -92,6 +92,8 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
bool is_signed);
llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs,
bool is_signed);
llvm::Value* EmitIntegerPow(llvm::Value* lhs, llvm::Value* rhs,
bool is_signed);
virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,

View File

@ -146,6 +146,30 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
ComputeAndCompareR1<bool>(&builder, {}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, IntPow) {
XlaBuilder builder(TestName());
XlaOp lhs =
ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, -1, -2, 3, 5, 3, 1});
XlaOp rhs =
ConstantR1<int32>(&builder, {0, 3, 3, 3, 3, 3, 2, 3, 2, 10, -100, -2});
Pow(lhs, rhs);
std::vector<int32> expected = {1, 1, 8, 27, 64, 125, 1, -8, 9, 9765625, 0, 1};
ComputeAndCompareR1<int32>(&builder, expected, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, IntPowLarge) {
XlaBuilder builder(TestName());
XlaOp lhs = ConstantR1<int64>(&builder, {2});
XlaOp rhs = ConstantR1<int64>(&builder, {62});
Pow(lhs, rhs);
std::vector<int64> expected = {4611686018427387904};
ComputeAndCompareR1<int64>(&builder, expected, {});
}
// A non-canonical quiet NaN value.
static const float kNonCanonicalNaN = absl::bit_cast<float>(0x7FD01234);