From 20e248e1584a59d82122c58f100404a5a6bc030c Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 28 Oct 2020 17:15:09 -0700 Subject: [PATCH] [XLA] Implement exponentiation by squaring for integers PiperOrigin-RevId: 339569265 Change-Id: I3fd1f9ba2659eefd3edec0cc7d5359c34b60fc04 --- .../xla/service/elemental_ir_emitter.cc | 29 +++++++++++++++++++ .../xla/service/elemental_ir_emitter.h | 2 ++ .../xla/tests/array_elementwise_ops_test.cc | 24 +++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index d3e00d04dfd..8da04b51093 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1578,6 +1578,33 @@ llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs, Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem)); } +llvm::Value* ElementalIrEmitter::EmitIntegerPow(llvm::Value* base, + llvm::Value* exponent, + bool is_signed) { + // Exponentiation by squaring: + // https://en.wikipedia.org/wiki/Exponentiation_by_squaring; + int bits = 6; // Higher exponent bits are ignored: if |base| > 1, the result + // overflows a 64-bit integer once the exponent reaches 64, and + // that's 1 << 6. + llvm::Value* accumulator = llvm::ConstantInt::get(base->getType(), 1); + llvm::Value* one = llvm::ConstantInt::get(exponent->getType(), 1); + llvm::Value* zero = llvm::ConstantInt::get(exponent->getType(), 0); + llvm::Value* original_base = base; + llvm::Value* original_exponent = exponent; + + // Unroll the loop at compile time. 
+ for (int i = 0; i < bits; i++) { + accumulator = + b_->CreateSelect(b_->CreateICmpEQ(b_->CreateAnd(exponent, one), one), + b_->CreateMul(accumulator, base), accumulator); + base = b_->CreateMul(base, base); + exponent = b_->CreateLShr(exponent, 1); + } + return b_->CreateSelect( + b_->CreateICmpSGE(original_exponent, zero), accumulator, + b_->CreateSelect(b_->CreateICmpEQ(original_base, one), one, zero)); +} + StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) { @@ -1627,6 +1654,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( return And(lhs_value, rhs_value); case HloOpcode::kOr: return Or(lhs_value, rhs_value); + case HloOpcode::kPower: + return EmitIntegerPow(lhs_value, rhs_value, is_signed); case HloOpcode::kXor: return Xor(lhs_value, rhs_value); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 56833159647..60e25c7d8bf 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -92,6 +92,8 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> { bool is_signed); llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs, bool is_signed); + llvm::Value* EmitIntegerPow(llvm::Value* lhs, llvm::Value* rhs, + bool is_signed); virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index ef4ce24a839..fc49c9249d7 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -146,6 +146,30 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { ComputeAndCompareR1<bool>(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, IntPow) { + XlaBuilder builder(TestName()); + 
XlaOp lhs = + ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, -1, -2, 3, 5, 3, 1}); + XlaOp rhs = + ConstantR1<int32>(&builder, {0, 3, 3, 3, 3, 3, 2, 3, 2, 10, -100, -2}); + Pow(lhs, rhs); + + std::vector<int32> expected = {1, 1, 8, 27, 64, 125, 1, -8, 9, 9765625, 0, 1}; + + ComputeAndCompareR1<int32>(&builder, expected, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, IntPowLarge) { + XlaBuilder builder(TestName()); + XlaOp lhs = ConstantR1<int64>(&builder, {2}); + XlaOp rhs = ConstantR1<int64>(&builder, {62}); + Pow(lhs, rhs); + + std::vector<int64> expected = {4611686018427387904}; + + ComputeAndCompareR1<int64>(&builder, expected, {}); +} + // A non-canonical quiet NaN value. static const float kNonCanonicalNaN = absl::bit_cast<float>(0x7FD01234);