/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

using llvm_ir::AsStringRef;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
using tensorflow::strings::StrCat;

namespace {

llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
                                      int64 mantissa_bits,
                                      llvm::IRBuilder<>* ir_builder) {
  // Integer and float types for casting and constant generation.
  llvm::Type* float_type = x->getType();
  llvm::IntegerType* int_type = ir_builder->getInt32Ty();

  // Cast the input value to an integer for bitwise manipulation.
  llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type);

  if (mantissa_bits < 23) {
    // Last remaining mantissa bit.
    const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);

    // Compute rounding bias for round-to-nearest with ties to even.  This is
    // equal to a base value of 0111... plus one bit if the last remaining
    // mantissa bit is 1.
    const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
    llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr(
        ir_builder->CreateAnd(
            x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
        (23 - mantissa_bits));
    llvm::Value* x_rounding_bias = ir_builder->CreateAdd(
        x_last_mantissa_bit,
        llvm::ConstantInt::get(int_type, base_rounding_bias));

    // Add rounding bias, and mask out truncated bits.  Note that the case
    // where adding the rounding bias overflows into the exponent bits is
    // correct; the non-masked mantissa bits will all be zero, and the
    // exponent will be incremented by one.
    const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
    x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias);
    x_as_int = ir_builder->CreateAnd(
        x_as_int, llvm::ConstantInt::get(int_type, truncation_mask));
  }
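
  // Worked example (illustrative comment only, not emitted code): with
  // mantissa_bits = 10, last_mantissa_bit_mask = 0x2000 and
  // base_rounding_bias = 0x0fff, so the add-then-mask above rounds f32 bits
  // 0x3f800fff down to 0x3f800000 and 0x3f801001 up to 0x3f802000.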

  if (exponent_bits < 8) {
    // Masks for f32 values.
    const uint32_t f32_sign_bit_mask = 1u << 31;
    const uint32_t f32_exp_bits_mask = 0xffu << 23;

    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes.  Adding
    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
    // exponent (corresponding to 0.0f).
    //
    // Thus, the f32 exponent corresponding to the highest non-infinite
    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
    // exponent corresponding to the lowest exponent for a bit size of n is
    // (2^7-1) - 2^(n-1)-1.
    //
    // Note that we have already checked that exponent_bits >= 1.
    const uint32_t f32_exponent_bias = (1 << 7) - 1;
    const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
    const uint32_t reduced_max_exponent =
        f32_exponent_bias + reduced_exponent_bias;
    const uint32_t reduced_min_exponent =
        f32_exponent_bias - reduced_exponent_bias;

    // Do we overflow or underflow?
    llvm::Value* x_exponent = ir_builder->CreateAnd(
        x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
    llvm::Value* x_overflows = ir_builder->CreateICmpUGT(
        x_exponent,
        llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
    llvm::Value* x_underflows = ir_builder->CreateICmpULE(
        x_exponent,
        llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));

    // Compute appropriately-signed values of zero and infinity.
    llvm::Value* x_signed_zero = ir_builder->CreateAnd(
        x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
    llvm::Value* x_signed_inf = ir_builder->CreateOr(
        x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));

    // Force to zero or infinity if overflow or underflow.  (Note that this
    // truncates all denormal values to zero, rather than rounding them.)
    x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int);
    x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int);
  }

  // Cast the result back to a floating-point type.
  llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type);

  // Correct result for NaN inputs.
  //
  // The exponent handling will "normalize" NaN values to infinities, which is
  // undesirable (except in the case with no mantissa bits, in which case it
  // is mandatory).  This logic also handles cases where mantissa-rounding
  // causes a NaN's mantissa to overflow into the exponent bits, which would
  // otherwise create an erroneous zero value.
  //
  // If the fast-math flags are set to assume no NaNs, the comparison is likely
  // to be optimized away, so there's no point in even emitting it.
  if (!ir_builder->getFastMathFlags().noNaNs()) {
    llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x);

    if (mantissa_bits > 0) {
      result = ir_builder->CreateSelect(x_is_nan, x, result);
    } else {
      result = ir_builder->CreateSelect(
          x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
    }
  }
  return result;
}

llvm::Value* EmitF32ToBF16(llvm::Value* f32_value,
                           llvm::IRBuilder<>* ir_builder) {
  auto reduced_precision = EmitReducePrecisionFloat(
      f32_value,
      /*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
      /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder);
  auto as_int32 =
      ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty());
  auto shifted = ir_builder->CreateLShr(as_int32, 16);
  auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty());
  return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty());
}

llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value,
                           llvm::IRBuilder<>* ir_builder) {
  auto as_int16 =
      ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty());
  auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty());
  auto shifted = ir_builder->CreateShl(as_int32, 16);
  return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy());
}

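// Illustrative note: BF16 is just the high 16 bits of an F32 after the
// rounding above (e.g. f32 1.0 = 0x3f800000 becomes bf16 0x3f80), and the
// reverse conversion zero-fills the low 16 bits, so EmitBF16ToF32 is exact.
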
llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
                                    PrimitiveType from_type,
                                    PrimitiveType to_type, llvm::Module* module,
                                    llvm::IRBuilder<>* ir_builder) {
  if (primitive_util::IsSignedIntegralType(from_type)) {
    return ir_builder->CreateSIToFP(
        integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
  } else {
    CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
          from_type == PRED);
    return ir_builder->CreateUIToFP(
        integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
  }
}

}  // namespace

StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  if (op->opcode() == HloOpcode::kCopy) {
    return operand_value;
  } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
             op->operand(0)->shape().element_type() == PRED) {
    return EmitIntegerUnaryOp(op, operand_value);
  } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
    return EmitComplexUnaryOp(op, operand_value);
  } else {
    return EmitFloatUnaryOp(op, operand_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED);
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::IsIntegralType(to_type)) {
        return ir_builder_->CreateIntCast(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
            primitive_util::IsSignedIntegralType(to_type));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        if (to_type == BF16) {
          return EmitF32ToBF16(
              EmitIntegralToFloating(operand_value, from_type, F32, module_,
                                     ir_builder_),
              ir_builder_);
        }
        return EmitIntegralToFloating(operand_value, from_type, to_type,
                                      module_, ir_builder_);
      }
      if (primitive_util::IsComplexType(to_type)) {
        auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
            primitive_util::ComplexComponentType(to_type), module_);
        if (primitive_util::IsSignedIntegralType(from_type)) {
          return EmitComposeComplex(
              op,
              ir_builder_->CreateSIToFP(operand_value, to_ir_component_type),
              nullptr);
        }
        if (primitive_util::IsUnsignedIntegralType(from_type) ||
            from_type == PRED) {
          return EmitComposeComplex(
              op,
              ir_builder_->CreateUIToFP(operand_value, to_ir_component_type),
              nullptr);
        }
      }
      return Unimplemented("conversion from primitive type %s to %s",
                           PrimitiveType_Name(from_type).c_str(),
                           PrimitiveType_Name(to_type).c_str());
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return ir_builder_->CreateBitCast(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type).c_str(),
          PrimitiveType_Name(to_type).c_str(),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kAbs: {
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
      if (is_signed) {
        auto type =
            llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
        auto zero = llvm::ConstantInt::get(type, 0);
        auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero);
        return ir_builder_->CreateSelect(cmp, operand_value,
                                         ir_builder_->CreateNeg(operand_value));
      } else {
        return operand_value;
      }
    }
    case HloOpcode::kSign: {
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
      auto type =
          llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
      auto zero = llvm::ConstantInt::get(type, 0);
      auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero);
      if (is_signed) {
        auto ashr = ir_builder_->CreateAShr(operand_value,
                                            type->getIntegerBitWidth() - 1);
        return ir_builder_->CreateSelect(cmp, zero,
                                         ir_builder_->CreateOr(ashr, 1));
      } else {
        return ir_builder_->CreateSelect(cmp, zero,
                                         llvm::ConstantInt::get(type, 1));
      }
    }
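    // Illustrative note on kSign above: for signed types the arithmetic
    // shift right by (bitwidth - 1) yields 0 for non-negative inputs and -1
    // (all ones) for negative inputs; OR-ing with 1 maps these to +1 and -1,
    // and the compare against zero handles the remaining case.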
    case HloOpcode::kNegate:
      return ir_builder_->CreateNeg(operand_value);
    case HloOpcode::kNot: {
      auto type = op->shape().element_type();
      if (type == PRED) {
        // It is not sufficient to just call CreateNot() here because a PRED
        // is represented as an i8 and the truth value is stored only in the
        // bottom bit.
        return ir_builder_->CreateZExt(
            ir_builder_->CreateNot(ir_builder_->CreateTrunc(
                operand_value, ir_builder_->getInt1Ty())),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      } else if (primitive_util::IsIntegralType(type)) {
        return ir_builder_->CreateNot(operand_value);
      }
      return Unimplemented("unary op Not is not defined for type '%d'", type);
    }
    default:
      return Unimplemented("unary integer op '%s'",
                           HloOpcodeString(op->opcode()).c_str());
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::IsComplexType(to_type)) {
        PrimitiveType to_component_type =
            primitive_util::ComplexComponentType(to_type);
        if (from_type == to_component_type) {
          return EmitComposeComplex(op, operand_value, nullptr);
        }
        return EmitComposeComplex(
            op,
            ir_builder_->CreateFPCast(
                operand_value,
                llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
            nullptr);
      }
      if (from_type == BF16) {
        TF_RET_CHECK(to_type != BF16);
        operand_value = EmitBF16ToF32(operand_value, ir_builder_);
        from_type = F32;
        if (from_type == to_type) {
          return operand_value;
        }
      }
      if (from_type == F32 && to_type == BF16) {
        return EmitF32ToBF16(operand_value, ir_builder_);
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        return ir_builder_->CreateFPCast(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      if (primitive_util::IsSignedIntegralType(to_type)) {
        return ir_builder_->CreateFPToSI(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      if (primitive_util::IsUnsignedIntegralType(to_type)) {
        return ir_builder_->CreateFPToUI(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return Unimplemented("unhandled conversion operation: %s => %s",
                           PrimitiveType_Name(from_type).c_str(),
                           PrimitiveType_Name(to_type).c_str());
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return ir_builder_->CreateBitCast(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type).c_str(),
          PrimitiveType_Name(to_type).c_str(),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kExp:
      return EmitExp(op->shape().element_type(), operand_value);
    case HloOpcode::kLog:
      return EmitLog(op->shape().element_type(), operand_value);
    case HloOpcode::kCos:
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()},
          ir_builder_);
    case HloOpcode::kCeil:
      return llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::ceil, {operand_value}, {operand_value->getType()},
          ir_builder_);
    case HloOpcode::kAbs:
      return llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::fabs, {operand_value}, {operand_value->getType()},
          ir_builder_);
    case HloOpcode::kRoundNearestAfz:
      return llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::round, {operand_value}, {operand_value->getType()},
          ir_builder_);
    case HloOpcode::kSign: {
      // TODO(b/32151903): Ensure consistent sign behavior for -0.0
      auto type = operand_value->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero);
      auto olt = ir_builder_->CreateFCmpOLT(operand_value, zero);
      return ir_builder_->CreateSelect(
          oeq, zero,
          ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0),
                                    llvm::ConstantFP::get(type, 1.0)));
    }
    case HloOpcode::kIsFinite: {
      // (x == x) && abs(x) != inf
      auto type = operand_value->getType();
      auto equal_self =
          ir_builder_->CreateFCmpOEQ(operand_value, operand_value);
      auto abs_value = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_);
      auto infinity = llvm::ConstantFP::getInfinity(type);
      auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
      auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
      return ir_builder_->CreateZExt(
          result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
    }
    case HloOpcode::kNegate:
      return ir_builder_->CreateFNeg(operand_value);
    default:
      return Unimplemented("unary floating-point op '%s'",
                           HloOpcodeString(op->opcode()).c_str());
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  PrimitiveType input_type = op->operand(0)->shape().element_type();
  PrimitiveType component_type =
      primitive_util::IsComplexType(input_type)
          ? primitive_util::ComplexComponentType(input_type)
          : input_type;
  switch (op->opcode()) {
    case HloOpcode::kLog: {
      // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      llvm::Type* llvm_ty = a->getType();
      auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
                                            ir_builder_->CreateFMul(b, b));
      TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
      TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
      auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
      return EmitComposeComplex(
          op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle);
    }
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(from_type));
      PrimitiveType to_type = op->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(to_type));
      if (from_type == to_type) {
        return operand_value;
      }
      PrimitiveType to_component_type =
          primitive_util::ComplexComponentType(to_type);
      auto to_ir_component_type =
          llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFPCast(EmitExtractReal(operand_value),
                                    to_ir_component_type),
          ir_builder_->CreateFPCast(EmitExtractImag(operand_value),
                                    to_ir_component_type));
    }
    case HloOpcode::kExp: {
      // e^(a+bi) = e^a*(cos(b)+sin(b)i)
      TF_ASSIGN_OR_RETURN(
          auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
                                ir_builder_->CreateFMul(exp_a, sin_b));
    }
    case HloOpcode::kCos: {
      // cos(z) = .5(e^(iz) + e^(-iz))
      // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
      //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
      auto half_exp_b =
          ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b =
          ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFMul(
              cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
          ir_builder_->CreateFMul(
              sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
    }
    case HloOpcode::kSin: {
      // sin(z) = .5i(e^(-iz) - e^(iz))
      // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
      //           = .5i(e^(b-ai) - e^(-b+ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
      //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
      //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
      auto half_exp_b =
          ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b =
          ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFMul(
              sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
          ir_builder_->CreateFMul(
              cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
    }
    case HloOpcode::kTanh: {
      /*
      tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
      e^(a+bi) = e^a*(cos(b)+sin(b)i)
      so tanh = (((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
                (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
      cos(b) = cos(-b), sin(-b) = -sin(b)
      so tanh = (((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
                (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
              = (cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
                (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
              = (cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
                (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
      This is a complex division, so we can multiply by denom_conj/denom_conj
              = (cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
                (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
                ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
              = (cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
                 i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
                ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
      */
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
      TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
      TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
      auto exp_neg_a = ir_builder_->CreateFDiv(
          llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
      auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub(
          ir_builder_->CreateFMul(exp_a, exp_a),
          ir_builder_->CreateFMul(exp_neg_a, exp_neg_a));
      auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b);
      auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b);
      auto real_num = ir_builder_->CreateFAdd(
          ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
          ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
      auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b);
      auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a);
      auto exp_a_plus_exp_neg_a_sq =
          ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
      auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a);
      auto exp_a_minus_exp_neg_a_sq =
          ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
      auto imag_num = ir_builder_->CreateFMul(
          cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq,
                                               exp_a_minus_exp_neg_a_sq));
      auto denom = ir_builder_->CreateFAdd(
          ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
          ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
      return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom),
                                ir_builder_->CreateFDiv(imag_num, denom));
    }
    case HloOpcode::kAbs: {
      auto sum_sq = ir_builder_->CreateFAdd(
          ir_builder_->CreateFMul(EmitExtractReal(operand_value),
                                  EmitExtractReal(operand_value)),
          ir_builder_->CreateFMul(EmitExtractImag(operand_value),
                                  EmitExtractImag(operand_value)));
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
                                          {sum_sq->getType()}, ir_builder_);
    }
    case HloOpcode::kSign: {  // Sign(c) = c / |c|
      auto sum_sq = ir_builder_->CreateFAdd(
          ir_builder_->CreateFMul(EmitExtractReal(operand_value),
                                  EmitExtractReal(operand_value)),
          ir_builder_->CreateFMul(EmitExtractImag(operand_value),
                                  EmitExtractImag(operand_value)));
      auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_);
      auto type = cplx_abs->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero);
      return ir_builder_->CreateSelect(
          oeq, EmitComposeComplex(op, zero, zero),
          EmitComposeComplex(
              op,
              ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs),
              ir_builder_->CreateFDiv(EmitExtractImag(operand_value),
                                      cplx_abs)));
    }
    case HloOpcode::kNegate:
      return EmitComposeComplex(
          op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)),
          ir_builder_->CreateFNeg(EmitExtractImag(operand_value)));
    case HloOpcode::kReal:
      return EmitExtractReal(operand_value);
    case HloOpcode::kImag:
      return EmitExtractImag(operand_value);
    default:
      return Unimplemented("unary complex op '%s'",
                           HloOpcodeString(op->opcode()).c_str());
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value,
    llvm::Value* rhs_value) const {
  PrimitiveType operand_type = op->operand(0)->shape().element_type();
  if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
      operand_type == PRED) {
    return EmitIntegerBinaryOp(
        op, lhs_value, rhs_value,
        primitive_util::IsSignedIntegralType(operand_type));
  } else if (primitive_util::IsComplexType(operand_type)) {
    return EmitComplexBinaryOp(op, lhs_value, rhs_value);
  } else {
    return EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value,
    llvm::Value* rhs_value) const {
  switch (op->opcode()) {
    case HloOpcode::kComplex:
      return EmitComposeComplex(op, lhs_value, rhs_value);
    case HloOpcode::kAdd:
      return ir_builder_->CreateFAdd(lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return ir_builder_->CreateFSub(lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return ir_builder_->CreateFMul(lhs_value, rhs_value);
    case HloOpcode::kDivide:
      return ir_builder_->CreateFDiv(lhs_value, rhs_value);
    case HloOpcode::kRemainder:
      return ir_builder_->CreateFRem(lhs_value, rhs_value);
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison.  This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kEq:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kNe:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kLt:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kGt:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kLe:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kGe:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
                                     rhs_value, ir_builder_);

    case HloOpcode::kMaximum:
      return EmitFloatMax(lhs_value, rhs_value);
    case HloOpcode::kMinimum:
      return EmitFloatMin(lhs_value, rhs_value);
    case HloOpcode::kPower:
      return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
    case HloOpcode::kAtan2:
      return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
    default:
      return Unimplemented("binary floating point op '%s'",
                           HloOpcodeString(op->opcode()).c_str());
  }
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value,
    llvm::Value* rhs_value) const {
  switch (op->opcode()) {
    case HloOpcode::kAdd:
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFAdd(EmitExtractReal(lhs_value),
                                  EmitExtractReal(rhs_value)),
          ir_builder_->CreateFAdd(EmitExtractImag(lhs_value),
                                  EmitExtractImag(rhs_value)));
    case HloOpcode::kSubtract:
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFSub(EmitExtractReal(lhs_value),
                                  EmitExtractReal(rhs_value)),
          ir_builder_->CreateFSub(EmitExtractImag(lhs_value),
                                  EmitExtractImag(rhs_value)));
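    // (a+bi)(c+di) = (ac-bd) + (ad+bc)i, which is what the FMul/FSub/FAdd
    // combination below computes.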
    case HloOpcode::kMultiply:
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFSub(
              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                      EmitExtractReal(rhs_value)),
              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                      EmitExtractImag(rhs_value))),
          ir_builder_->CreateFAdd(
              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                      EmitExtractImag(rhs_value)),
              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                      EmitExtractReal(rhs_value))));
    case HloOpcode::kDivide: {
      // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
      //                 = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
      auto rhs_sum_sq = ir_builder_->CreateFAdd(
          ir_builder_->CreateFMul(EmitExtractReal(rhs_value),
                                  EmitExtractReal(rhs_value)),
          ir_builder_->CreateFMul(EmitExtractImag(rhs_value),
                                  EmitExtractImag(rhs_value)));
      auto type = rhs_sum_sq->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero);
      auto real_inf_or_nan =
          ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero);
      auto imag_inf_or_nan =
          ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero);
      return ir_builder_->CreateSelect(
          oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan),
          EmitComposeComplex(
              op,
              ir_builder_->CreateFDiv(
                  ir_builder_->CreateFAdd(
                      ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                              EmitExtractReal(rhs_value)),
                      ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                              EmitExtractImag(rhs_value))),
                  rhs_sum_sq),
              ir_builder_->CreateFDiv(
                  ir_builder_->CreateFSub(
                      ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                              EmitExtractReal(rhs_value)),
                      ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                              EmitExtractImag(rhs_value))),
                  rhs_sum_sq)));
    }
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison.  This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kEq:
      return ir_builder_->CreateAnd(
          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                  EmitExtractReal(lhs_value),
                                  EmitExtractReal(rhs_value), ir_builder_),
          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                  EmitExtractImag(lhs_value),
                                  EmitExtractImag(rhs_value), ir_builder_));
    case HloOpcode::kNe:
      return ir_builder_->CreateOr(
          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                  EmitExtractReal(lhs_value),
                                  EmitExtractReal(rhs_value), ir_builder_),
          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                  EmitExtractImag(lhs_value),
                                  EmitExtractImag(rhs_value), ir_builder_));

    case HloOpcode::kPower: {
      // (a+bi)^(c+di) =
      //    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
      //    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
      PrimitiveType component_type =
          primitive_util::ComplexComponentType(op->shape().element_type());
      auto a = EmitExtractReal(lhs_value);
      auto b = EmitExtractImag(lhs_value);
      auto c = EmitExtractReal(rhs_value);
      auto d = EmitExtractImag(rhs_value);
      auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
                                             ir_builder_->CreateFMul(b, b));
      auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
      auto half_c = ir_builder_->CreateFMul(one_half, c);

      TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
                          EmitPow(component_type, aa_p_bb, half_c));
      auto neg_d = ir_builder_->CreateFNeg(d);
      TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
      auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs);
      TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
                          EmitExp(component_type, neg_d_arg_lhs));
      auto coeff =
          ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
      TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
      auto half_d = ir_builder_->CreateFMul(one_half, d);
      auto q =
          ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs),
                                  ir_builder_->CreateFMul(half_d, ln_aa_p_bb));
      TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
      TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
      return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q),
                                ir_builder_->CreateFMul(coeff, sin_q));
    }
    default:
      return Unimplemented("binary complex op '%s'",
                           HloOpcodeString(op->opcode()).c_str());
  }
}

llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value) const {
  return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_);
}

llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value) const {
  return llvm_ir::EmitFloatMin(lhs_value, rhs_value, ir_builder_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
                                                       llvm::Value* x) const {
  if (prim_type != F32) {
    return Unimplemented("inverse erf only implemented for F32 (b/34339814)");
  }
  auto getFloat = [&](const float f) {
    return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f);
  };
  auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
                          llvm::Value* w) {
    llvm::Value* p = getFloat(coefficients.front());
    coefficients.pop_front();
    for (float coefficient : coefficients) {
      p = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(p, w),
                                  getFloat(coefficient));
    }
    return p;
  };
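
  // Illustrative note: multiply_add evaluates its coefficient list with
  // Horner's rule, i.e. p = (...((c0*w + c1)*w + c2)...)*w + c_last.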

  // Approximation for inverse error function from
  //   Giles, M., "Approximating the erfinv function".
  // The approximation has the form:
  //   w = -log((1-x)*(1+x))
  //   if ( w < 5 ) {
  //     w = w - 2.5
  //     p = sum_{i=1}^n lq[i]*w^i
  //   } else {
  //     w = sqrt(w) - 3
  //     p = sum_{i=1}^n gq[i]*w^i
  //   }
  //   return p*x
  llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
      module_, llvm::Intrinsic::log, {ir_builder_->getFloatTy()});

  llvm::Value* w = ir_builder_->CreateFNeg(ir_builder_->CreateCall(
      logf_fn,
      {ir_builder_->CreateFMul(ir_builder_->CreateFSub(getFloat(1.0f), x),
                               ir_builder_->CreateFAdd(getFloat(1.0f), x))}));

  llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry(
      ir_builder_->getFloatTy(), "p.addr", ir_builder_);

  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(ir_builder_->CreateFCmpOLT(w, getFloat(5.0f)),
                              "w_less_than_five", ir_builder_);
  // Handle true BB.
  SetToFirstInsertPoint(if_data.true_block, ir_builder_);
  {
    llvm::Value* lw = ir_builder_->CreateFSub(w, getFloat(2.5f));
    tensorflow::gtl::ArraySlice<float> lq{
        2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
        -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
        -0.00417768164f,  0.246640727f,    1.50140941f};
    llvm::Value* p = multiply_add(lq, lw);
    ir_builder_->CreateStore(p, p_addr);
  }

  // Handle false BB.
  SetToFirstInsertPoint(if_data.false_block, ir_builder_);
  {
    llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
        module_, llvm::Intrinsic::sqrt, {ir_builder_->getFloatTy()});

    llvm::Value* gw = ir_builder_->CreateFSub(
        ir_builder_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f));
    tensorflow::gtl::ArraySlice<float> gq{
        -0.000200214257f, 0.000100950558f, 0.00134934322f,
        -0.00367342844f,  0.00573950773f,  -0.0076224613f,
        0.00943887047f,   1.00167406f,     2.83297682f};
    llvm::Value* p = multiply_add(gq, gw);
    ir_builder_->CreateStore(p, p_addr);
  }

  SetToFirstInsertPoint(if_data.after_block, ir_builder_);
  llvm::Value* p = ir_builder_->CreateLoad(p_addr);
  return ir_builder_->CreateFMul(p, x);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
    PrimitiveType prim_type, llvm::Value* value) const {
  // Compute erfcinv(value) by calculating erfinv(1.0 - value).
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                   llvm::Value* value) const {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
                                      {value->getType()}, ir_builder_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                   llvm::Value* value) const {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
                                      {value->getType()}, ir_builder_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                   llvm::Value* value) const {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
                                      {value->getType()}, ir_builder_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
                                                   llvm::Value* value) const {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
                                      {value->getType()}, ir_builder_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                   llvm::Value* lhs,
                                                   llvm::Value* rhs) const {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
                                      {lhs->getType()}, ir_builder_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
                                                     llvm::Value* lhs,
                                                     llvm::Value* rhs) const {
  return Unimplemented("atan2");
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
    const HloInstruction* hlo, llvm::Value* x) const {
  if (hlo->operand(0)->shape().element_type() != F32) {
    return Unimplemented("reduce-precision only implemented for F32");
  }
  return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
                                  /*mantissa_bits=*/hlo->mantissa_bits(),
                                  ir_builder_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
    bool is_signed) const {
  switch (op->opcode()) {
    // TODO(jingyue): add the "nsw" attribute for signed types.
    case HloOpcode::kAdd:
      return ir_builder_->CreateAdd(lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return ir_builder_->CreateSub(lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return ir_builder_->CreateMul(lhs_value, rhs_value);
    case HloOpcode::kDivide:
      return is_signed ? ir_builder_->CreateSDiv(lhs_value, rhs_value)
                       : ir_builder_->CreateUDiv(lhs_value, rhs_value);
    case HloOpcode::kRemainder:
      return is_signed ? ir_builder_->CreateSRem(lhs_value, rhs_value)
                       : ir_builder_->CreateURem(lhs_value, rhs_value);
    case HloOpcode::kEq:
      return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kNe:
      return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kLt:
      return llvm_ir::EmitComparison(
          is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
          lhs_value, rhs_value, ir_builder_);
    case HloOpcode::kGt:
      return llvm_ir::EmitComparison(
          is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
          lhs_value, rhs_value, ir_builder_);
    case HloOpcode::kLe:
      return llvm_ir::EmitComparison(
          is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
          lhs_value, rhs_value, ir_builder_);
    case HloOpcode::kGe:
      return llvm_ir::EmitComparison(
          is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
          lhs_value, rhs_value, ir_builder_);
    case HloOpcode::kMinimum:
      return ir_builder_->CreateSelect(
          ir_builder_->CreateICmp(
              is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
              lhs_value, rhs_value),
          lhs_value, rhs_value);
    case HloOpcode::kMaximum:
      return ir_builder_->CreateSelect(
          ir_builder_->CreateICmp(
              is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
              lhs_value, rhs_value),
          lhs_value, rhs_value);
    case HloOpcode::kAnd:
      return ir_builder_->CreateAnd(lhs_value, rhs_value);
    case HloOpcode::kOr:
      return ir_builder_->CreateOr(lhs_value, rhs_value);
    case HloOpcode::kShiftLeft:
      return ir_builder_->CreateShl(lhs_value, rhs_value);
    case HloOpcode::kShiftRightArithmetic:
      return ir_builder_->CreateAShr(lhs_value, rhs_value);
    case HloOpcode::kShiftRightLogical:
      return ir_builder_->CreateLShr(lhs_value, rhs_value);
    default:
      return Unimplemented("binary integer op '%s'",
                           HloOpcodeString(op->opcode()).c_str());
  }
}

llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
    const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
    int64 operand_no) const {
  CHECK(hlo.IsElementwise())
      << "HLO " << hlo.ToString() << " is not elementwise.";

  const Shape& operand_shape = hlo.operand(operand_no)->shape();
  // If the operand is scalar, the source index is always {}.
  if (ShapeUtil::IsScalar(operand_shape)) {
    return llvm_ir::IrArray::Index();
  }

  // If no implicit broadcast is needed for this operand, returns the target
  // index as the source index.
  if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) {
    return target_index;
  }

  // If implicit broadcast is needed, the source dimensions that are broadcast
  // have index 0.
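  // (Illustrative example: with operand shape [3,1] and hlo shape [3,4],
  // dimension 1 of the source index is always 0.)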
  CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape()));
  llvm_ir::IrArray::Index source_index;
  for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) {
    if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) {
      source_index.push_back(target_index[i]);
    } else {
      CHECK_EQ(1, operand_shape.dimensions(i));
      source_index.push_back(ir_builder_->getInt64(0));
    }
  }
  return source_index;
}

llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
    const {
  PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type();
  llvm::Type* param_ir_type =
      llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_);

  // Same values as PCG library
  // https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h
  llvm::Value* multiplier = ir_builder_->getInt(
      llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4}));
  llvm::Value* increment = ir_builder_->getInt(
      llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D}));
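  // (Illustrative note: llvm::APInt treats the word list as little-endian,
  // least-significant 64-bit word first, so each pair above is {low, high}
  // of one 128-bit constant.)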

  auto random_value = [hlo]() {
    const HloModule* module =
        hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent()
                       : hlo->parent()->parent();
    return module->RandomNew64();
  };

  // Seed each RNG emitter with a new 64-bit seed from the HloModule.  If the
  // compilation order is deterministic (i.e., the RandomNew64 invocation
  // order is deterministic), then the sequence of RNG values is deterministic
  // for a given seed, and hence tests will be deterministic.
  // If the user provides a global seed instruction, we use only 64 bits of
  // the host's random number generator to seed the 128-bit state; the other
  // 64 bits come from the user-specified global seed instruction.
  // Create a GlobalVariable to maintain state between invocations.  There is
  // a bug in NVPTX with GlobalVariable and 128-bit values, so we use two
  // 64-bit values instead.
  llvm::GlobalVariable* state_ptr0 = new llvm::GlobalVariable(
      /*M=*/*module_,
      /*Ty=*/ir_builder_->getInt64Ty(),
      /*isConstant=*/false,
      /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
      /*Initializer=*/ir_builder_->getInt64(random_value()),
      /*Name=*/"state_ptr0");
  uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed()
                                                     : random_value();
  llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable(
      /*M=*/*module_,
      /*Ty=*/ir_builder_->getInt64Ty(),
      /*isConstant=*/false,
      /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
      /*Initializer=*/ir_builder_->getInt64(graph_seed),
      /*Name=*/"state_ptr1");

  // We want each thread to use its own stream, so we modify the increment per
  // thread.  We want the increment to remain odd, so we shift the thread id
  // left 1 and add it to the increment.
  increment = ir_builder_->CreateAdd(increment,
                                     ir_builder_->CreateShl(EmitThreadId(), 1));

  // PCG-XSL-RR algorithm
  // http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf
  //   state = multiplier * state + increment
  //   return uint64_t(state ^ (state >> 64)) >>> (state >> 122)
  // where ">>>" is bitwise rotation
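  // For reference only (an illustrative sketch, not the emitted IR), the
  // recurrence corresponds to:
  //   state = state * multiplier + increment;            // 128-bit
  //   uint64_t xored = (uint64_t)(state ^ (state >> 64));
  //   uint64_t rot = (uint64_t)(state >> 122);           // top 6 bits
  //   return (xored >> rot) | (xored << ((-rot) & 63));  // rotate right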
  auto get_next_i64 = [=]() {
    llvm::Value* state0 = ir_builder_->CreateZExtOrTrunc(
        ir_builder_->CreateLoad(state_ptr0, "state0"),
        ir_builder_->getInt128Ty());
    llvm::Value* state1 = ir_builder_->CreateShl(
        ir_builder_->CreateZExtOrTrunc(
            ir_builder_->CreateLoad(state_ptr1, "state1"),
            ir_builder_->getInt128Ty()),
        64);
    llvm::Value* state = ir_builder_->CreateOr(state0, state1);
    llvm::Value* updated = ir_builder_->CreateAdd(
        ir_builder_->CreateMul(state, multiplier), increment);
    ir_builder_->CreateStore(
        ir_builder_->CreateTrunc(updated, ir_builder_->getInt64Ty()),
        state_ptr0);
    ir_builder_->CreateStore(
        ir_builder_->CreateTrunc(ir_builder_->CreateLShr(updated, 64),
                                 ir_builder_->getInt64Ty()),
        state_ptr1);

    return llvm_ir::CreateRor(
        ir_builder_->CreateTrunc(
            ir_builder_->CreateXor(state, ir_builder_->CreateLShr(state, 64)),
            ir_builder_->getInt64Ty()),
        ir_builder_->CreateTrunc(ir_builder_->CreateLShr(state, 122),
                                 ir_builder_->getInt64Ty()),
        ir_builder_);
  };

  auto get_next_uniform_float = [=]() {
    return ir_builder_->CreateFDiv(
        ir_builder_->CreateUIToFP(get_next_i64(), param_ir_type),
        llvm::ConstantFP::get(param_ir_type, 0x1p64));
  };

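  // Illustrative note: dividing the 64-bit integer sample by 0x1p64 (2^64)
  // above maps it into the half-open interval [0, 1).
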
  return [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
    switch (hlo->random_distribution()) {
      case RNG_UNIFORM: {
        TF_ASSIGN_OR_RETURN(llvm::Value * p,
                            operand_to_generator.at(hlo->operand(0))(index));
        TF_ASSIGN_OR_RETURN(llvm::Value * q,
                            operand_to_generator.at(hlo->operand(1))(index));
        if (primitive_util::IsFloatingPointType(param_prim_type)) {
          return ir_builder_->CreateFAdd(
              ir_builder_->CreateFMul(ir_builder_->CreateFSub(q, p),
                                      get_next_uniform_float()),
              p);
        } else {
          auto r = ir_builder_->CreateSub(q, p);
          auto leading_zeros = llvm_ir::EmitCallToIntrinsic(
              llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)},
              {param_ir_type}, ir_builder_);
          auto in_block = ir_builder_->GetInsertBlock();

          // A terminator should be present iff we're emitting code
          // into the middle (as opposed to the end) of a basic block.
          CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(),
                   in_block->getTerminator() == nullptr);

          llvm::BasicBlock* body_block;
          llvm::BasicBlock* out_block;

          if (ir_builder_->GetInsertPoint() == in_block->end()) {
            body_block = llvm_ir::CreateBasicBlock(
                nullptr, IrName(hlo, "rng_body"), ir_builder_);
            out_block = llvm_ir::CreateBasicBlock(
                nullptr, IrName(hlo, "rng_out"), ir_builder_);
            llvm::BranchInst::Create(body_block, in_block);
          } else {
            body_block = in_block->splitBasicBlock(
                ir_builder_->GetInsertPoint(), "rng_body");
            out_block = body_block->splitBasicBlock(
                ir_builder_->GetInsertPoint(), "rng_out");
            body_block->getTerminator()->eraseFromParent();
          }

          SetToFirstInsertPoint(body_block, ir_builder_);
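          // Illustrative note: the AND below masks the raw sample down to an
          // all-ones mask just covering r - 1 (derived from ctlz(r)), and the
          // conditional branch retries until the masked sample is < r --
          // rejection sampling that keeps the result uniform over [0, r).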
          auto random = ir_builder_->CreateAnd(
              ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type),
              ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0),
                                      leading_zeros));
          llvm::BranchInst::Create(out_block, body_block,
                                   ir_builder_->CreateICmpULT(random, r),
                                   body_block);
          SetToFirstInsertPoint(out_block, ir_builder_);
          return ir_builder_->CreateAdd(
              p, ir_builder_->CreateSelect(
                     ir_builder_->CreateICmpEQ(p, q),
                     llvm::ConstantInt::get(param_ir_type, 0), random));
        }
      }
      case RNG_NORMAL: {
        TF_ASSIGN_OR_RETURN(llvm::Value * m,
                            operand_to_generator.at(hlo->operand(0))(index));
        TF_ASSIGN_OR_RETURN(llvm::Value * s,
                            operand_to_generator.at(hlo->operand(1))(index));
        TF_ASSIGN_OR_RETURN(
            llvm::Value * r,
            EmitErfcInv(param_prim_type,
                        ir_builder_->CreateFMul(
                            llvm::ConstantFP::get(param_ir_type, 2.0),
                            get_next_uniform_float())));
        return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m);
      }
      case RNG_BERNOULLI: {
        TF_ASSIGN_OR_RETURN(llvm::Value * p,
                            operand_to_generator.at(hlo->operand(0))(index));
        PrimitiveType element_type = hlo->shape().element_type();
        llvm::Value* zero;
        llvm::Value* one;
        llvm::Type* result_ir_type = llvm_ir::PrimitiveTypeToIrType(
            hlo->shape().element_type(), module_);
        if (primitive_util::IsFloatingPointType(element_type)) {
          zero = llvm::ConstantFP::get(result_ir_type, 0.0);
          one = llvm::ConstantFP::get(result_ir_type, 1.0);
        } else if (primitive_util::IsIntegralType(element_type)) {
          zero = llvm::ConstantInt::get(result_ir_type, 0);
          one = llvm::ConstantInt::get(result_ir_type, 1);
        } else {
          return Unimplemented(
              "Rng Bernoulli unimplemented for requested type!");
        }

        return ir_builder_->CreateSelect(
            ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p), one, zero);
      }
      default:
        return InvalidArgument(
            "unhandled distribution %s",
            RandomDistribution_Name(hlo->random_distribution()).c_str());
    }
  };
}

llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|
const HloInstruction* hlo,
|
|
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
|
|
const {
|
|
switch (hlo->opcode()) {
|
|
case HloOpcode::kAbs:
|
|
case HloOpcode::kRoundNearestAfz:
|
|
case HloOpcode::kCeil:
|
|
case HloOpcode::kConvert:
|
|
case HloOpcode::kBitcastConvert:
|
|
case HloOpcode::kCopy:
|
|
case HloOpcode::kCos:
|
|
case HloOpcode::kExp:
|
|
case HloOpcode::kFloor:
|
|
case HloOpcode::kImag:
|
|
case HloOpcode::kIsFinite:
|
|
case HloOpcode::kLog:
|
|
case HloOpcode::kNegate:
|
|
case HloOpcode::kNot:
|
|
case HloOpcode::kReal:
|
|
case HloOpcode::kSign:
|
|
case HloOpcode::kSin:
|
|
case HloOpcode::kTanh:
|
|
return [this, hlo, &operand_to_generator](
|
|
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
|
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
|
|
operand_to_generator.at(hlo->operand(0))(
|
|
ElementwiseSourceIndex(index, *hlo, 0)));
|
|
return EmitUnaryOp(hlo, operand_value);
|
|
};
|
|
case HloOpcode::kAdd:
|
|
case HloOpcode::kAnd:
|
|
case HloOpcode::kAtan2:
|
|
case HloOpcode::kComplex:
|
|
case HloOpcode::kDivide:
|
|
case HloOpcode::kEq:
|
|
case HloOpcode::kGe:
|
|
case HloOpcode::kGt:
|
|
case HloOpcode::kLe:
|
|
case HloOpcode::kLt:
|
|
case HloOpcode::kMaximum:
|
|
case HloOpcode::kMinimum:
|
|
case HloOpcode::kMultiply:
|
|
case HloOpcode::kNe:
|
|
case HloOpcode::kOr:
|
|
case HloOpcode::kPower:
|
|
case HloOpcode::kRemainder:
|
|
case HloOpcode::kShiftLeft:
|
|
case HloOpcode::kShiftRightArithmetic:
|
|
case HloOpcode::kShiftRightLogical:
|
|
case HloOpcode::kSubtract:
|
|
return [this, hlo, &operand_to_generator](
|
|
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
|
const HloInstruction* lhs = hlo->operand(0);
|
|
const HloInstruction* rhs = hlo->operand(1);
|
|
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
|
|
operand_to_generator.at(lhs)(
|
|
ElementwiseSourceIndex(index, *hlo, 0)));
|
|
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
|
|
operand_to_generator.at(rhs)(
|
|
ElementwiseSourceIndex(index, *hlo, 1)));
|
|
return EmitBinaryOp(hlo, lhs_value, rhs_value);
|
|
};
|
|
case HloOpcode::kSelect:
|
|
return [this, hlo, &operand_to_generator](
|
|
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
|
TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
|
|
operand_to_generator.at(hlo->operand(0))(
|
|
ElementwiseSourceIndex(index, *hlo, 0)));
|
|
TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
|
|
operand_to_generator.at(hlo->operand(1))(
|
|
ElementwiseSourceIndex(index, *hlo, 1)));
|
|
TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
|
|
operand_to_generator.at(hlo->operand(2))(
|
|
ElementwiseSourceIndex(index, *hlo, 2)));
        return ir_builder_->CreateSelect(
            ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()),
            on_true_value, on_false_value);
      };
    case HloOpcode::kClamp:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
                            operand_to_generator.at(hlo->operand(0))(
                                ElementwiseSourceIndex(index, *hlo, 0)));
        TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
                            operand_to_generator.at(hlo->operand(1))(
                                ElementwiseSourceIndex(index, *hlo, 1)));
        TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
                            operand_to_generator.at(hlo->operand(2))(
                                ElementwiseSourceIndex(index, *hlo, 2)));
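        // clamp(min, x, max) is composed as min(max, max(min, x)): first raise
        // x to at least min_value, then cap the result at max_value.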
        return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
      };
    case HloOpcode::kReducePrecision:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(
                                ElementwiseSourceIndex(index, *hlo, 0)));
        return EmitReducePrecision(hlo, operand_value);
      };
    case HloOpcode::kConcatenate:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
        const int64 concat_dim = hlo->dimensions(0);
        auto source_index = target_index;
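        // Illustrative example (values not from the original source): when
        // concatenating shapes [2,3] and [2,5] along dim 1, a target index
        // [i,6] exceeds operand 0's extent (3), so 3 is subtracted and the
        // element is read from operand 1 at source index [i,3].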

        llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock();

        // A terminator should be present iff we're emitting code
        // into the middle (as opposed to the end) of a basic block.
        CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(),
                 init_block->getTerminator() == nullptr);

        llvm::BasicBlock* exit_block;
        if (ir_builder_->GetInsertPoint() == init_block->end()) {
          exit_block = llvm_ir::CreateBasicBlock(
              /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_);
        } else {
          exit_block = init_block->splitBasicBlock(
              ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge")));
          init_block->getTerminator()->eraseFromParent();
        }

        llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
        llvm::PHINode* output =
            ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType(
                                       hlo->shape().element_type(), module_),
                                   hlo->operands().size());
        auto prior_insert_point = ir_builder_->GetInsertPoint();

        ir_builder_->SetInsertPoint(init_block);

        for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
             ++operand_idx) {
          const HloInstruction* operand = hlo->operand(operand_idx);
          auto true_block = llvm_ir::CreateBasicBlock(
              exit_block, StrCat("concat_index_from_operand", operand_idx),
              ir_builder_);
          auto false_block = llvm_ir::CreateBasicBlock(
              exit_block, StrCat("concat_index_not_from_operand", operand_idx),
              ir_builder_);
          auto concat_dim_size =
              llvm::ConstantInt::get(source_index[concat_dim]->getType(),
                                     operand->shape().dimensions(concat_dim));
          ir_builder_->CreateCondBr(
              ir_builder_->CreateICmpULT(source_index[concat_dim],
                                         concat_dim_size),
              true_block, false_block);

          // Create the terminator of the true block before calling operand
          // generators, because they require non-degenerate basic blocks.
          ir_builder_->SetInsertPoint(
              llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block));
          TF_ASSIGN_OR_RETURN(llvm::Value * value,
                              operand_to_generator.at(operand)(source_index));
          output->addIncoming(value, ir_builder_->GetInsertBlock());

          // Subtract the size of the concat dimension of the current operand
          // from the source index.
          ir_builder_->SetInsertPoint(false_block);
          source_index[concat_dim] =
              ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size);
        }

        ir_builder_->CreateUnreachable();
        ir_builder_->SetInsertPoint(exit_block, prior_insert_point);
        return output;
      };
    case HloOpcode::kReverse:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        auto source_index = target_index;
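        // For each reversed dimension of size N, index j maps to N - 1 - j;
        // e.g. (illustrative), with N == 5, target index 1 reads source 3.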
        for (int64 dim : hlo->dimensions()) {
          source_index[dim] = ir_builder_->CreateSub(
              llvm::ConstantInt::get(target_index[dim]->getType(),
                                     hlo->shape().dimensions(dim) - 1),
              target_index[dim]);
        }
        return operand_to_generator.at(operand)(source_index);
      };
    case HloOpcode::kBroadcast:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        // The `dimensions` member of the broadcast instruction maps from
        // input dimensions to output dimensions.
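        // Illustrative example (not from the original source): broadcasting
        // an operand of shape [3] into shape [2,3] with dimensions={1} makes
        // source_index[0] read target_index[1].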
        const HloInstruction* operand = hlo->operand(0);
        int64 rank = ShapeUtil::Rank(operand->shape());
        IrArray::Index source_index(rank);
        for (int64 i = 0; i < rank; ++i) {
          source_index[i] = target_index[hlo->dimensions(i)];
        }
        return operand_to_generator.at(operand)(source_index);
      };
    case HloOpcode::kSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        IrArray::Index sliced_index = index.SourceIndexOfSlice(
            /*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(),
            /*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_);
        return operand_to_generator.at(hlo->operand(0))(sliced_index);
      };
    case HloOpcode::kDynamicSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        // Emit IR to read dynamic start indices from hlo->operand(1).
        const HloInstruction* input_hlo = hlo->operand(0);
        const int64 rank = ShapeUtil::Rank(input_hlo->shape());
        llvm_ir::IrArray::Index slice_start_index(rank);
        for (int64 i = 0; i < rank; ++i) {
          llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
          TF_ASSIGN_OR_RETURN(
              llvm::Value * start_index_value,
              operand_to_generator.at(hlo->operand(1))(dim_index));
          start_index_value->setName(
              AsStringRef(IrName(hlo, StrCat("start_idx", i))));
          slice_start_index[i] = start_index_value;
        }

        llvm_ir::IrArray::Index input_index(rank);
        for (int64 i = 0; i < rank; ++i) {
          // Emit IR which computes:
          //   input_index = (start_index + offset_index) % dim_size
          // Security note: this is the code that keeps the indices in-bounds.
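          // Illustrative example (values not from the original source): with
          // dim_size == 10, start_index == 8 and offset 3, the wrapped input
          // index is (8 + 3) % 10 == 1.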
          llvm::Value* dim_size = llvm::ConstantInt::get(
              index[i]->getType(), input_hlo->shape().dimensions(i));
          llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
              slice_start_index[i], index[i]->getType());
          input_index[i] = ir_builder_->CreateURem(
              ir_builder_->CreateAdd(start_index, index[i]), dim_size);
        }
        return operand_to_generator.at(input_hlo)(input_index);
      };
    case HloOpcode::kDynamicUpdateSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* input_hlo = hlo->operand(0);
        const HloInstruction* update_hlo = hlo->operand(1);
        const HloInstruction* start_hlo = hlo->operand(2);
        // Calculate slice start/end indices.
        const int64 rank = ShapeUtil::Rank(input_hlo->shape());
        llvm_ir::IrArray::Index slice_start_index(rank);
        llvm_ir::IrArray::Index slice_limit_index(rank);
        // The slice starts at update[index - slice_start_index_adjusted],
        // where the adjusted value equals slice_start_index when in bounds,
        // and slice_start_index - input_dim when the slice wraps around.
        llvm_ir::IrArray::Index slice_start_index_adjusted(rank);

        // The slice intersection ANDs together the per-dimension conditions;
        // where all of them hold, 'input' is overwritten with 'update'.
        llvm::Value* slice_intersection = ir_builder_->getTrue();

        for (int64 i = 0; i < rank; ++i) {
          // Emit IR to read dynamic start indices from 'start_hlo'.
          llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
          TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
                              operand_to_generator.at(start_hlo)(dim_index));
          start_index_value->setName(
              AsStringRef(IrName(hlo, StrCat("start_idx", i))));
          slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
              start_index_value, index[i]->getType());

          llvm::Value* input_dim_size = llvm::ConstantInt::get(
              index[i]->getType(), input_hlo->shape().dimensions(i));
          llvm::Value* update_dim_size = llvm::ConstantInt::get(
              index[i]->getType(), update_hlo->shape().dimensions(i));

          // Generate code to handle wrapping semantics:
          //   slice_start_index[i] = slice_start_index[i] % input_dim_size;
          //   slice_limit_index[i] = slice_start_index[i] + update_dim_size.
          // slice_start_index[i] is updated in place and will now be in
          // range; slice_limit_index[i] may still be out of range, and is
          // URem-ed below if so.
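          // Illustrative example (values not from the original source): with
          // input_dim_size == 10, start == 8 and update_dim_size == 4, the
          // slice wraps: the limit becomes 12 and rows 8, 9, 0, 1 are updated.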
          slice_start_index[i] =
              ir_builder_->CreateURem(slice_start_index[i], input_dim_size);
          slice_limit_index[i] =
              ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);

          // Test whether slice_limit_index[i] is in bounds.
          llvm::Value* in_bounds =
              ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size);
          llvm_ir::LlvmIfData if_in_bounds =
              llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);

          // Handle true BB (slice_limit_index[i] <= input_dim_size).
          SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
          // Check that index[i] >= slice_start_index[i] &&
          //            index[i] < slice_limit_index[i].
          llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd(
              slice_intersection,
              ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
              "slice_intersection_in");
          slice_intersection_in_bounds = ir_builder_->CreateAnd(
              slice_intersection_in_bounds,
              ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
              "slice_intersection_in");

          // Handle false BB (slice_limit_index[i] > input_dim_size).
          SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_);
          // Check that index[i] >= slice_start_index[i] ||
          //            index[i] < slice_limit_index[i] % input_dim_size.
          llvm::Value* index_wraps = ir_builder_->CreateICmpSLT(
              index[i],
              ir_builder_->CreateURem(slice_limit_index[i], input_dim_size));
          llvm::Value* slice_intersection_or = ir_builder_->CreateOr(
              ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
              index_wraps, "slice_intersection_out");
          llvm::Value* slice_intersection_out_of_bounds =
              ir_builder_->CreateAnd(slice_intersection, slice_intersection_or,
                                     "slice_intersection_out");
          // Create the value for slice_start_index_adjusted[i] when out of
          // bounds, via an 'if' nested within the out-of-bounds branch.
          llvm_ir::LlvmIfData if_start_needs_adjustment =
              llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_);
          SetToFirstInsertPoint(if_start_needs_adjustment.true_block,
                                ir_builder_);
          llvm::Value* slice_start_index_adjusted_oob =
              ir_builder_->CreateSub(slice_start_index[i], input_dim_size);
          SetToFirstInsertPoint(if_start_needs_adjustment.after_block,
                                ir_builder_);
          llvm::PHINode* slice_start_index_adjusted_phi =
              ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(),
                                     2);
          slice_start_index_adjusted_phi->addIncoming(
              slice_start_index_adjusted_oob,
              if_start_needs_adjustment.true_block);
          slice_start_index_adjusted_phi->addIncoming(
              slice_start_index[i], if_start_needs_adjustment.false_block);
          // End of the nested 'if'.

          // After checking in/out of bounds.
          SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_);
          llvm::PHINode* phi_slice_intersection =
              ir_builder_->CreatePHI(slice_intersection->getType(), 2);
          phi_slice_intersection->addIncoming(slice_intersection_in_bounds,
                                              if_in_bounds.true_block);
          phi_slice_intersection->addIncoming(
              slice_intersection_out_of_bounds,
              if_start_needs_adjustment.after_block);
          slice_intersection = phi_slice_intersection;

          llvm::PHINode* phi_index =
              ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2);
          phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block);
          phi_index->addIncoming(slice_start_index_adjusted_phi,
                                 if_start_needs_adjustment.after_block);
          slice_start_index_adjusted[i] = phi_index;
        }

        // Emit:
        //   if (slice_intersection) -> return data from 'update'.
        //   else                    -> return data from 'input'.
        llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
            llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
                                           module_),
            "ret_value_addr", ir_builder_);
        llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
            slice_intersection, "slice_intersection", ir_builder_);

        // Handle true BB (return data from 'update').
        SetToFirstInsertPoint(if_data.true_block, ir_builder_);
        // Compute the update index for the intersection case.
        llvm_ir::IrArray::Index update_index(rank);
        for (int64 i = 0; i < rank; ++i) {
          llvm::Value* update_dim_size = llvm::ConstantInt::get(
              index[i]->getType(), update_hlo->shape().dimensions(i));
          // NOTE: The subtraction is non-negative due to the bounds checking
          // above.
          update_index[i] = ir_builder_->CreateURem(
              ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]),
              update_dim_size);
        }
        TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
                            operand_to_generator.at(update_hlo)(update_index));
        ir_builder_->CreateStore(true_value, ret_value_addr);

        // Handle false BB (return data from 'input').
        SetToFirstInsertPoint(if_data.false_block, ir_builder_);
        TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
                            operand_to_generator.at(input_hlo)(index));
        ir_builder_->CreateStore(false_value, ret_value_addr);

        SetToFirstInsertPoint(if_data.after_block, ir_builder_);
        return ir_builder_->CreateLoad(ret_value_addr);
      };
    case HloOpcode::kReshape:
      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        const HloInstruction* operand = hlo->operand(0);
        return operand_to_generator.at(operand)(index.SourceIndexOfReshape(
            hlo->shape(), operand->shape(), ir_builder_));
      };
    case HloOpcode::kTranspose:
      return [this, hlo,
              &operand_to_generator](const IrArray::Index& target_index) {
        return operand_to_generator.at(hlo->operand(0))(
            target_index.SourceIndexOfTranspose(
                hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(),
                ir_builder_));
      };
    case HloOpcode::kRng:
      return MakeRngElementGenerator(hlo, operand_to_generator);
    case HloOpcode::kPad:
      return [=, &operand_to_generator](
                 const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
        auto index = padded_index;
        llvm::Value* in_bounds = ir_builder_->getTrue();
        for (size_t i = 0; i < index.size(); ++i) {
          auto index_typed_const = [=](int64 n) {
            return llvm::ConstantInt::get(index[i]->getType(), n);
          };
          const auto& pad_dim = hlo->padding_config().dimensions(i);
          index[i] = ir_builder_->CreateSub(
              index[i], index_typed_const(pad_dim.edge_padding_low()));
          in_bounds = ir_builder_->CreateAnd(
              in_bounds,
              ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
              "in_bounds");
          in_bounds = ir_builder_->CreateAnd(
              in_bounds,
              ir_builder_->CreateICmpEQ(
                  index_typed_const(0),
                  ir_builder_->CreateURem(
                      index[i],
                      index_typed_const(pad_dim.interior_padding() + 1))),
              "in_bounds");
          index[i] = ir_builder_->CreateSDiv(
              index[i], index_typed_const(pad_dim.interior_padding() + 1));
          in_bounds = ir_builder_->CreateAnd(
              in_bounds,
              ir_builder_->CreateICmpSLT(
                  index[i],
                  index_typed_const(hlo->operand(0)->shape().dimensions(i))),
              "in_bounds");
        }
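
        // Illustrative example (values not from the original source): with
        // edge_padding_low == 1 and interior_padding == 2, padded index 7
        // becomes 7 - 1 == 6; 6 % 3 == 0, so it lands on a source element
        // rather than on interior padding, at source index 6 / 3 == 2.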

        // if (in_bounds) {
        //   ret_value = operand0[index];  // source
        // } else {
        //   ret_value = *operand1;        // padding
        // }
        llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
            llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
                                           module_),
            "pad_result_addr", ir_builder_);
        llvm_ir::LlvmIfData if_data =
            llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
        SetToFirstInsertPoint(if_data.true_block, ir_builder_);
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(index));
        ir_builder_->CreateStore(operand_value, ret_value_addr);

        SetToFirstInsertPoint(if_data.false_block, ir_builder_);
        TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
                            operand_to_generator.at(hlo->operand(1))({}));
        ir_builder_->CreateStore(padding_value, ret_value_addr);

        SetToFirstInsertPoint(if_data.after_block, ir_builder_);
        // Don't create phi(operand_value, padding_value) here, because
        // invoking operand_to_generator may create new basic blocks, making
        // the parent of operand_value or padding_value no longer a
        // predecessor of if_data.after_block.
        return ir_builder_->CreateLoad(ret_value_addr);
      };

    case HloOpcode::kDot:
      return [=, &operand_to_generator](const IrArray::Index& dot_result_index)
                 -> StatusOr<llvm::Value*> {
        auto lhs_generator = operand_to_generator.at(hlo->operand(0));
        auto rhs_generator = operand_to_generator.at(hlo->operand(1));
        int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
            hlo->operand(0)->shape().dimensions_size() - 1);
        int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
        int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();

        std::unique_ptr<llvm_ir::ForLoop> inner_loop =
            llvm_ir::ForLoop::EmitForLoop(
                IrName(hlo, "inner"), ir_builder_->getInt64(0),
                ir_builder_->getInt64(contracted_dim_size),
                ir_builder_->getInt64(1), ir_builder_);

        SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(),
                              ir_builder_);
        PrimitiveType primitive_type = hlo->shape().element_type();
        llvm::Type* primitive_type_llvm =
            llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
        llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
            primitive_type_llvm, "dot_acc", ir_builder_);
        ir_builder_->CreateStore(
            llvm::Constant::getNullValue(primitive_type_llvm),
            accumulator_alloca);

        SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_);

        // This is the inner reduction loop for a dot operation that produces
        // one element in the output. If the operands to the dot operation
        // have shapes [A,B,C,T] and [D,T,E], the result has a shape
        // [A,B,C,D,E]. Given an output index [a,b,c,d,e] in the result, we
        // compute:
        //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
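        // Illustrative example (shapes not from the original source): for lhs
        // [2,3,4] and rhs [5,4,6], T == 4, the result shape is [2,3,5,6], and
        // each output element accumulates four products.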

        IrArray::Index lhs_index, rhs_index;

        for (int64 i = 0; i < lhs_dims - 1; i++) {
          lhs_index.push_back(dot_result_index[i]);
        }
        lhs_index.push_back(inner_loop->GetIndVarValue());

        for (int64 i = 0; i < rhs_dims - 2; i++) {
          rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
        }
        rhs_index.push_back(inner_loop->GetIndVarValue());
        rhs_index.push_back(dot_result_index.back());

        llvm::Value* current_accumulator =
            ir_builder_->CreateLoad(accumulator_alloca);
        TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
        TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
        llvm::Value* next_accumulator;
        if (primitive_util::IsComplexType(primitive_type)) {
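          // Complex product: (a + bi)(c + di) == (ac - bd) + (ad + bc)i, so
          // the real and imaginary accumulator lanes are updated separately.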
          llvm::Value* product_real = ir_builder_->CreateFSub(
              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                      EmitExtractReal(rhs_value)),
              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                      EmitExtractImag(rhs_value)));
          llvm::Value* product_imag = ir_builder_->CreateFAdd(
              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                      EmitExtractImag(rhs_value)),
              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                      EmitExtractReal(rhs_value)));
          next_accumulator = ir_builder_->CreateInsertValue(
              current_accumulator,
              ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator),
                                      product_real),
              {0});
          next_accumulator = ir_builder_->CreateInsertValue(
              next_accumulator,
              ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator),
                                      product_imag),
              {1});
        } else if (primitive_util::IsFloatingPointType(primitive_type)) {
          next_accumulator = ir_builder_->CreateFAdd(
              current_accumulator,
              ir_builder_->CreateFMul(lhs_value, rhs_value));
        } else {
          next_accumulator = ir_builder_->CreateAdd(
              current_accumulator,
              ir_builder_->CreateMul(lhs_value, rhs_value));
        }
        ir_builder_->CreateStore(next_accumulator, accumulator_alloca);

        SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_);
        return ir_builder_->CreateLoad(accumulator_alloca);
      };
    default:
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        return Unimplemented("Unhandled opcode for elemental IR emission: %s",
                             HloOpcodeString(hlo->opcode()).c_str());
      };
  }
}

llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const {
  return ir_builder_->CreateExtractValue(value, {0});
}

llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const {
  return ir_builder_->CreateExtractValue(value, {1});
}

llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
                                                    llvm::Value* real,
                                                    llvm::Value* imag) const {
  auto cplx_type =
      llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
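  // Complex values are represented as a two-element aggregate {real, imag}:
  // start from a zero-initialized aggregate and insert the components.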
  auto complex = ir_builder_->CreateInsertValue(
      llvm::ConstantAggregateZero::get(cplx_type), real, {0});
  if (imag != nullptr) {
    complex = ir_builder_->CreateInsertValue(complex, imag, {1});
  }
  return complex;
}

}  // namespace xla