STT-tensorflow/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
A. Unique TensorFlower f3c89936e9 Merge changes from github.
END_PUBLIC

Note: this CL will break builds.  cl/159887762 to follow to fix all the breakages.

---
Commit 2336cdf7f authored by Maxwell Paul Brickner<mbrickn@users.noreply.github.com>
Committed by gunan<gunan@google.com>:
Updated link to use HTTPS ()

Howdy!

I just updated a link to use https instead of http.

Thanks!
---
Commit ad0892df1 authored by Luke Iwanski<luke@codeplay.com>
Committed by Luke Iwanski<luke@codeplay.com>:
[OpenCL] Fixes run_metadata_test for SYCL

This test is designed to test CUDA-specific behavior.

---
Commit 6b37a0725 authored by Todd Wang<toddwang@gmail.com>
Committed by GitHub<noreply@github.com>:
Update comments
---
Commit 1699d904a authored by John Lawson<john@codeplay.com>
Committed by Luke Iwanski<luke@codeplay.com>:
[OpenCL] Fixes CUDA-specific test run on SYCL ()

The testBadParentValuesOnGPU should only be run on CUDA devices, as the
test checks for particular CUDA behaviour. We don't actually provide a
SYCL kernel for GatherTree and so it's not a problem that the tests
don't target SYCL.
---
Commit 3c1946230 authored by myPrecious<Moriadry@users.noreply.github.com>
Committed by Shanqing Cai<cais@google.com>:
Java API to get the size of the specified input list of operations. ()

* Java API to get the size of specified input list of operations

* Remove an unnecessary explanation to avoid introducing a new term to users.

---
Commit e911c7480 authored by Luke Iwanski<luke@codeplay.com>
Committed by Luke Iwanski<luke@codeplay.com>:
[OpenCL] REGISTER -> REGISTER6

---
Commit fbf6c4cec authored by superryanguo<superryanguo@gmail.com>
Committed by superryanguo<superryanguo@gmail.com>:
Simplify the Quickstart section; using the weblink is better

---
Commit 72e2918cc authored by Taehoon Lee<taehoonlee@snu.ac.kr>
Committed by Taehoon Lee<taehoonlee@snu.ac.kr>:
Fix typos

---
Commit 90c4406b7 authored by Rishabh Patel<patelrishabh@users.noreply.github.com>
Committed by GitHub<noreply@github.com>:
Correct the learning rate as per the code snippet
---
Commit 03da61134 authored by Todd Wang<toddwang@gmail.com>
Committed by GitHub<noreply@github.com>:
Update ir_array.cc
---
Commit 2df6cd3ac authored by Todd Wang<toddwang@gmail.com>
Committed by GitHub<noreply@github.com>:
Another try
---
Commit af0cbace1 authored by Luke Iwanski<luke@codeplay.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
[OpenCL] Transpose to go through Eigen ()

---
Commit fc7361081 authored by Luke Iwanski<luke@codeplay.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
[OpenCL] Registers RGBToHSV and HSVToRGB () ()

* [OpenCL] Added RGBToHSV and HSVToRGB

* Aligning '\'
---
Commit 832894ef8 authored by Luke Iwanski<luke@codeplay.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
[OpenCL] Registers AdjustContrastv2 ()

* [OpenCL] Registers AdjustContrastv2 ()

* [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL ()

* [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL

* simplified to #ifndef

* Changed to "#if GOOGLE_CUDA"

* Update adjust_contrast_op_benchmark_test.cc

* Added comments

---
Commit cb4c2f8d1 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Make TransferBufferToInFeed not virtual so it compiles.

---
Commit e89f04d80 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Fix calling Literal member functions.

---
Commit 15a8df724 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Fix mac build
Cloned from meheff's change:
[XLA] Change return type of DeviceAssignment::Deserialize to fix build
breakage on mac.
The mac build had the following error:

error: incomplete type 'xla::DeviceAssignment' used in type trait
expression

This was due to a static method returning a StatusOr<DeviceAssignment>
inside of the definition of DeviceAssignment.

---
Commit a54d43fa4 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Replace LiteralUtil with Literal in compiler/plugin/executor

---
Commit 88a6bb80c authored by Guenther Schmuelling<guschmue@microsoft.com>
Committed by Guenther Schmuelling<guschmue@microsoft.com>:
expand inline for debug builds to limit number of symbols

---
Commit 62fb49d31 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Fix visibility error for contrib/remote_fused_graph/pylib/BUILD.

---
Commit 4c75252f2 authored by Mark Neumann<markn@allenai.org>
Committed by Mark Neumann<markn@allenai.org>:
fix initial test values to avoid numerical instability

---
Commit b58d98353 authored by sj6077<epik03sj@gmail.com>
Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
Fixes an AutoParallel bug ()

* Fix a bug where auto_parallel could replicate the variable snapshot name

* Use NodeName in grappler:utils instead of substr, convert variables->variable_def of grappler item

* remove variable_def from grappler item, exclude snapshot nodes from dont_replicate_nodes in auto_parallel

---
Commit a286b7db8 authored by Yifei Feng<yifeif@google.com>
Committed by Yifei Feng<yifeif@google.com>:
Make debug_test slice integer.

---
Commit 97fcfdfa6 authored by Toby Boyd<tobyboyd@google.com>
Committed by GitHub<noreply@github.com>:
Fixed path to seq2seq.py and minor formatting
---
Commit 63c1befb8 authored by Anish Shah<shah.anish07@gmail.com>
Committed by Anish Shah<shah.anish07@gmail.com>:
Improve docs for tf.nn.depthwise_conv2d_native

---
Commit 8d42202b2 authored by Yong Tang<yong.tang.github@outlook.com>
Committed by Yong Tang<yong.tang.github@outlook.com>:
Fix mismatched delete in mkl_tfconv_op.cc

This fixes a mismatched new[]/delete in mkl_tfconv_op.cc

(the file went through clang-format so there are some additional
changes)

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

---
Commit 26301bd55 authored by Danny Goodman<goodman.danny@gmail.com>
Committed by Danny Goodman<goodman.danny@gmail.com>:
fix error format

---
Commit b3f33ad46 authored by Yao Zhang<yaozhang@google.com>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Make changes to prepare for the fused option of batch norm to be set to None (None means using fused batch norm if possible).

PiperOrigin-RevId: 159649743

---
Commit a4a469832 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
[XLA] Add tests for select ops and while loops that produce tuples that contain predicates.

PiperOrigin-RevId: 159645900

---
Commit 980d3f2be authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Use C API to implement Operation.name property

This name property is used in many existing tests including those that
already run with C API enabled (math_ops_test, framework_ops_test,
session_test, session_partial_run_test, math_ops_test_gpu, etc).

PiperOrigin-RevId: 159645767

---
Commit 26239c706 authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
Previously we didn't have an implementation of BatchNormInference and BatchNormTraining, which gave a linker error if anyone ever tried to call them. A dummy implementation is friendlier than a linker error.

PiperOrigin-RevId: 159645612

---
Commit f671c5caa authored by A. Unique TensorFlower<gardener@tensorflow.org>
Committed by TensorFlower Gardener<gardener@tensorflow.org>:
BEGIN_PUBLIC
Automated g4 rollback of changelist 159570549

PiperOrigin-RevId: 160182040

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "external/llvm/include/llvm/IR/BasicBlock.h"
#include "external/llvm/include/llvm/IR/Instructions.h"
#include "external/llvm/include/llvm/IR/Intrinsics.h"
#include "external/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
using llvm_ir::IrArray;
using llvm_ir::SetToFirstInsertPoint;
StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
if (op->opcode() == HloOpcode::kCopy) {
return operand_value;
} else {
return operand_value->getType()->isIntegerTy()
? EmitIntegerUnaryOp(op, operand_value)
: EmitFloatUnaryOp(op, operand_value);
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
switch (op->opcode()) {
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
PrimitiveType to_type = op->shape().element_type();
CHECK(primitive_util::IsIntegralType(from_type));
if (from_type == to_type) {
return operand_value;
}
if (primitive_util::IsIntegralType(to_type)) {
return ir_builder_->CreateIntCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_),
primitive_util::IsSignedIntegralType(to_type));
}
if (primitive_util::IsFloatingPointType(to_type)) {
if (primitive_util::IsSignedIntegralType(from_type)) {
return ir_builder_->CreateSIToFP(
operand_value,
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
}
if (primitive_util::IsUnsignedIntegralType(from_type)) {
return ir_builder_->CreateUIToFP(
operand_value,
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
}
}
return Unimplemented("conversion from primitive type %s to %s",
PrimitiveType_Name(from_type).c_str(),
PrimitiveType_Name(to_type).c_str());
}
case HloOpcode::kAbs: {
bool is_signed =
primitive_util::IsSignedIntegralType(op->shape().element_type());
if (is_signed) {
auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(),
ir_builder_);
auto zero = llvm::ConstantInt::get(type, 0);
auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero);
return ir_builder_->CreateSelect(cmp, operand_value,
ir_builder_->CreateNeg(operand_value));
} else {
return operand_value;
}
}
case HloOpcode::kSign: {
bool is_signed =
primitive_util::IsSignedIntegralType(op->shape().element_type());
auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(),
ir_builder_);
auto zero = llvm::ConstantInt::get(type, 0);
auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero);
if (is_signed) {
auto ashr = ir_builder_->CreateAShr(operand_value,
type->getIntegerBitWidth() - 1);
return ir_builder_->CreateSelect(cmp, zero,
ir_builder_->CreateOr(ashr, 1));
} else {
return ir_builder_->CreateSelect(cmp, zero,
llvm::ConstantInt::get(type, 1));
}
}
case HloOpcode::kNegate:
return ir_builder_->CreateNeg(operand_value);
case HloOpcode::kLogicalNot:
// It is not sufficient to just call CreateNot() here because a PRED is
// represented as an i8 and the truth value is stored only in the bottom
// bit.
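// For example (illustrative values): a PRED holding true is the i8 value 1;
// truncating to i1, negating, and zero-extending yields i8 0, whereas
// CreateNot on the full i8 would produce 0xFE.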
return ir_builder_->CreateZExt(
ir_builder_->CreateNot(ir_builder_->CreateTrunc(
operand_value, ir_builder_->getInt1Ty())),
llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
default:
return Unimplemented("unary integer op '%s'",
HloOpcodeString(op->opcode()).c_str());
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
switch (op->opcode()) {
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
PrimitiveType to_type = op->shape().element_type();
CHECK(primitive_util::IsFloatingPointType(from_type));
if (from_type == to_type) {
return operand_value;
}
if (primitive_util::IsFloatingPointType(to_type)) {
return ir_builder_->CreateFPCast(
operand_value,
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
}
if (primitive_util::IsSignedIntegralType(to_type)) {
return ir_builder_->CreateFPToSI(
operand_value,
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
}
if (primitive_util::IsUnsignedIntegralType(to_type)) {
return ir_builder_->CreateFPToUI(
operand_value,
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
}
return Unimplemented("unhandled conversion operation: %s => %s",
PrimitiveType_Name(from_type).c_str(),
PrimitiveType_Name(to_type).c_str());
}
case HloOpcode::kExp:
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value},
{operand_value->getType()},
ir_builder_);
case HloOpcode::kLog:
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value},
{operand_value->getType()},
ir_builder_);
case HloOpcode::kCos:
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {operand_value},
{operand_value->getType()},
ir_builder_);
case HloOpcode::kFloor:
return llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()},
ir_builder_);
case HloOpcode::kCeil:
return llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::ceil, {operand_value}, {operand_value->getType()},
ir_builder_);
case HloOpcode::kAbs:
return llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::fabs, {operand_value}, {operand_value->getType()},
ir_builder_);
case HloOpcode::kSign: {
// TODO(b/32151903): Ensure consistent sign behavior for -0.0
auto type = operand_value->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero);
auto olt = ir_builder_->CreateFCmpOLT(operand_value, zero);
return ir_builder_->CreateSelect(
oeq, zero,
ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0),
llvm::ConstantFP::get(type, 1.0)));
}
case HloOpcode::kIsFinite: {
// (x == x) && abs(x) != inf
auto type = operand_value->getType();
auto equal_self =
ir_builder_->CreateFCmpOEQ(operand_value, operand_value);
auto abs_value = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_);
auto infinity = llvm::ConstantFP::getInfinity(type);
auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
return ir_builder_->CreateZExt(
result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
}
case HloOpcode::kNegate:
return ir_builder_->CreateFNeg(operand_value);
default:
return Unimplemented("unary floating-point op '%s'",
HloOpcodeString(op->opcode()).c_str());
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
return lhs_value->getType()->isIntegerTy()
? EmitIntegerBinaryOp(op, lhs_value, rhs_value,
primitive_util::IsSignedIntegralType(
op->operand(0)->shape().element_type()))
: EmitFloatBinaryOp(op, lhs_value, rhs_value);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
switch (op->opcode()) {
case HloOpcode::kAdd:
return ir_builder_->CreateFAdd(lhs_value, rhs_value);
case HloOpcode::kSubtract:
return ir_builder_->CreateFSub(lhs_value, rhs_value);
case HloOpcode::kMultiply:
return ir_builder_->CreateFMul(lhs_value, rhs_value);
case HloOpcode::kDivide:
return ir_builder_->CreateFDiv(lhs_value, rhs_value);
case HloOpcode::kRemainder:
return ir_builder_->CreateFRem(lhs_value, rhs_value);
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
// unordered comparisons return true.
//
// We use ordered comparisons for everything except kNe, where we use an
// unordered comparison. This makes x != y equivalent to !(x == y), and
// matches C++'s semantics.
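// For example, with x = NaN: FCMP_OEQ(x, y) is false and FCMP_UNE(x, y) is
// true for any y, so kNe returns true whenever either operand is NaN, just as
// x != y does in C++.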
case HloOpcode::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
rhs_value, ir_builder_);
case HloOpcode::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
rhs_value, ir_builder_);
case HloOpcode::kLt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
rhs_value, ir_builder_);
case HloOpcode::kGt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
rhs_value, ir_builder_);
case HloOpcode::kLe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
rhs_value, ir_builder_);
case HloOpcode::kGe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
rhs_value, ir_builder_);
case HloOpcode::kMaximum:
return EmitFloatMax(lhs_value, rhs_value);
case HloOpcode::kMinimum:
return EmitFloatMin(lhs_value, rhs_value);
case HloOpcode::kPower:
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow,
{lhs_value, rhs_value},
{lhs_value->getType()}, ir_builder_);
default:
return Unimplemented("binary floating point op '%s'",
HloOpcodeString(op->opcode()).c_str());
}
}
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
{lhs_value, rhs_value},
{lhs_value->getType()}, ir_builder_);
}
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
{lhs_value, rhs_value},
{lhs_value->getType()}, ir_builder_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Value* x) const {
if (prim_type != F32) {
return Unimplemented("inverse erf only implemented for F32 (b/34339814)");
}
auto getFloat = [&](const float f) {
return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f);
};
auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
llvm::Value* w) {
llvm::Value* p = getFloat(coefficients.front());
coefficients.pop_front();
for (float coefficient : coefficients) {
p = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(p, w),
getFloat(coefficient));
}
return p;
};
// Approximation for inverse error function from
// Giles, M., "Approximating the erfinv function".
// The approximation has the form:
// w = log((1-x)*(1+x))
// if ( w < 5 ) {
// w = w - 2.5
// p = sum_{i=1}^n lq[i]*w^i
// } else {
// w = sqrt(w) - 3
// p = sum_{i=1}^n gq[i]*w^i
// }
// return p*x
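// The multiply_add lambda above evaluates these polynomials with Horner's
// rule, i.e. p = (((c_1 * w + c_2) * w + c_3) * w + ...), where the c_i are
// the lq/gq coefficients listed below in highest-degree-first order.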
llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
module_, llvm::Intrinsic::log, {ir_builder_->getFloatTy()});
llvm::Value* w = ir_builder_->CreateFNeg(ir_builder_->CreateCall(
logf_fn,
{ir_builder_->CreateFMul(ir_builder_->CreateFSub(getFloat(1.0f), x),
ir_builder_->CreateFAdd(getFloat(1.0f), x))}));
llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry(
ir_builder_->getFloatTy(), "p.addr", ir_builder_);
llvm_ir::LlvmIfData if_data =
llvm_ir::EmitIfThenElse(ir_builder_->CreateFCmpOLT(w, getFloat(5.0f)),
"w_less_than_five", ir_builder_);
// Handle true BB.
SetToFirstInsertPoint(if_data.true_block, ir_builder_);
{
llvm::Value* lw = ir_builder_->CreateFSub(w, getFloat(2.5f));
tensorflow::gtl::ArraySlice<float> lq{
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
-4.39150654e-06f, 0.00021858087f, -0.00125372503f,
-0.00417768164f, 0.246640727f, 1.50140941f};
llvm::Value* p = multiply_add(lq, lw);
ir_builder_->CreateStore(p, p_addr);
}
// Handle false BB.
SetToFirstInsertPoint(if_data.false_block, ir_builder_);
{
llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
module_, llvm::Intrinsic::sqrt, {ir_builder_->getFloatTy()});
llvm::Value* gw = ir_builder_->CreateFSub(
ir_builder_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f));
tensorflow::gtl::ArraySlice<float> gq{
-0.000200214257f, 0.000100950558f, 0.00134934322f,
-0.00367342844f, 0.00573950773f, -0.0076224613f,
0.00943887047f, 1.00167406f, 2.83297682f};
llvm::Value* p = multiply_add(gq, gw);
ir_builder_->CreateStore(p, p_addr);
}
SetToFirstInsertPoint(if_data.after_block, ir_builder_);
llvm::Value* p = ir_builder_->CreateLoad(p_addr);
return ir_builder_->CreateFMul(p, x);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
PrimitiveType prim_type, llvm::Value* value) const {
// Compute erfcinv(value) by calculating erfinv(1.0 - value).
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, ir_builder_);
auto one = llvm::ConstantFP::get(type, 1.0);
return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
const HloInstruction* hlo, llvm::Value* x) const {
if (hlo->operand(0)->shape().element_type() != F32) {
return Unimplemented("reduce-precision only implemented for F32");
}
// Integer and float types for casting and constant generation.
llvm::Type* float_type = x->getType();
llvm::IntegerType* int_type = ir_builder_->getInt32Ty();
// Cast the input value to an integer for bitwise manipulation.
llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type);
if (hlo->mantissa_bits() < 23) {
// Last remaining mantissa bit.
const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits());
// Compute rounding bias for round-to-nearest with ties to even. This is
// equal to a base value of 0111... plus one bit if the last remaining
// mantissa bit is 1.
const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
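// As an illustration (hypothetical value), with mantissa_bits() == 10:
// last_mantissa_bit_mask == 1u << 13 == 0x2000 and base_rounding_bias ==
// 0x0FFF, so the bias added below is 0x0FFF when the last kept mantissa bit
// is 0 and 0x1000 when it is 1, which implements round-half-to-even.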
llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr(
ir_builder_->CreateAnd(
x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
(23 - hlo->mantissa_bits()));
llvm::Value* x_rounding_bias = ir_builder_->CreateAdd(
x_last_mantissa_bit,
llvm::ConstantInt::get(int_type, base_rounding_bias));
// Add rounding bias, and mask out truncated bits. Note that the case
// where adding the rounding bias overflows into the exponent bits is
// correct; the non-masked mantissa bits will all be zero, and the
// exponent will be incremented by one.
const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias);
x_as_int = ir_builder_->CreateAnd(
x_as_int, llvm::ConstantInt::get(int_type, truncation_mask));
}
if (hlo->exponent_bits() < 8) {
// Masks for f32 values.
const uint32_t f32_sign_bit_mask = 1u << 31;
const uint32_t f32_exp_bits_mask = 0xffu << 23;
// An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
// significant bit -- is equal to 1.0f for all exponent sizes. Adding
// 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
// size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
// exponent (corresponding to 0.0f).
//
// Thus, the f32 exponent corresponding to the highest non-infinite
// exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
// exponent corresponding to the lowest exponent for a bit size of n is
// (2^7-1) - 2^(n-1)-1.
//
// Note that we have already checked that exponent_bits >= 1.
const uint32_t f32_exponent_bias = (1 << 7) - 1;
const uint32_t reduced_exponent_bias =
(1 << (hlo->exponent_bits() - 1)) - 1;
const uint32_t reduced_max_exponent =
f32_exponent_bias + reduced_exponent_bias;
const uint32_t reduced_min_exponent =
f32_exponent_bias - reduced_exponent_bias;
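// For example (illustrative value), with exponent_bits() == 5:
// reduced_exponent_bias == 15, reduced_max_exponent == 127 + 15 == 142, and
// reduced_min_exponent == 127 - 15 == 112, so biased f32 exponents above 142
// overflow to infinity and those at or below 112 underflow to zero.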
// Do we overflow or underflow?
llvm::Value* x_exponent = ir_builder_->CreateAnd(
x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
llvm::Value* x_overflows = ir_builder_->CreateICmpUGT(
x_exponent,
llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
llvm::Value* x_underflows = ir_builder_->CreateICmpULE(
x_exponent,
llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
// Compute appropriately-signed values of zero and infinity.
llvm::Value* x_signed_zero = ir_builder_->CreateAnd(
x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
llvm::Value* x_signed_inf = ir_builder_->CreateOr(
x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
// Force to zero or infinity if overflow or underflow. (Note that this
// truncates all denormal values to zero, rather than rounding them.)
x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int);
x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int);
}
// Cast the result back to a floating-point type.
llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type);
// Correct result for NaN inputs.
//
// The exponent handling will "normalize" NaN values to infinities, which is
// undesirable (except in the case with no mantissa bits, in which case it
// is mandatory). This logic also handles cases where mantissa-rounding
// causes a NaN's mantissa to overflow into the exponent bits, which would
// otherwise create an erroneous zero value.
//
// If the fast-math flags are set to assume no NaNs, the comparison is likely
// to be optimized away, so there's no point in even emitting it.
if (!ir_builder_->getFastMathFlags().noNaNs()) {
llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x);
if (hlo->mantissa_bits() > 0) {
result = ir_builder_->CreateSelect(x_is_nan, x, result);
} else {
result = ir_builder_->CreateSelect(
x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
}
}
return result;
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed) const {
switch (op->opcode()) {
// TODO(jingyue): add the "nsw" attribute for signed types.
case HloOpcode::kAdd:
return ir_builder_->CreateAdd(lhs_value, rhs_value);
case HloOpcode::kSubtract:
return ir_builder_->CreateSub(lhs_value, rhs_value);
case HloOpcode::kMultiply:
return ir_builder_->CreateMul(lhs_value, rhs_value);
case HloOpcode::kDivide:
return is_signed ? ir_builder_->CreateSDiv(lhs_value, rhs_value)
: ir_builder_->CreateUDiv(lhs_value, rhs_value);
case HloOpcode::kRemainder:
return is_signed ? ir_builder_->CreateSRem(lhs_value, rhs_value)
: ir_builder_->CreateURem(lhs_value, rhs_value);
case HloOpcode::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
rhs_value, ir_builder_);
case HloOpcode::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
rhs_value, ir_builder_);
case HloOpcode::kLt:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
lhs_value, rhs_value, ir_builder_);
case HloOpcode::kGt:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
lhs_value, rhs_value, ir_builder_);
case HloOpcode::kLe:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
lhs_value, rhs_value, ir_builder_);
case HloOpcode::kGe:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
lhs_value, rhs_value, ir_builder_);
case HloOpcode::kMinimum:
return ir_builder_->CreateSelect(
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
lhs_value, rhs_value),
lhs_value, rhs_value);
case HloOpcode::kMaximum:
return ir_builder_->CreateSelect(
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
lhs_value, rhs_value),
lhs_value, rhs_value);
case HloOpcode::kLogicalAnd:
return ir_builder_->CreateAnd(lhs_value, rhs_value);
case HloOpcode::kLogicalOr:
return ir_builder_->CreateOr(lhs_value, rhs_value);
default:
return Unimplemented("binary integer op '%s'",
HloOpcodeString(op->opcode()).c_str());
}
}
llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
int64 operand_no) const {
CHECK(hlo.IsElementwise())
<< "HLO " << hlo.ToString() << " is not elementwise.";
const Shape& operand_shape = hlo.operand(operand_no)->shape();
// If the operand is scalar, the source index is always {}.
if (ShapeUtil::IsScalar(operand_shape)) {
return llvm_ir::IrArray::Index();
}
// If no implicit broadcast is needed for this operand, returns the target
// index as the source index.
if (ShapeUtil::Compatible(operand_shape, hlo.shape())) {
return target_index;
}
// If implicit broadcast is needed, the source dimensions that are broadcast
// have index 0.
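// For example, if hlo has shape f32[3,4] and the operand has shape f32[3,1],
// target index {i, j} maps to source index {i, 0}.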
CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape()));
llvm_ir::IrArray::Index source_index;
for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) {
if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) {
source_index.push_back(target_index[i]);
} else {
CHECK_EQ(1, operand_shape.dimensions(i));
source_index.push_back(ir_builder_->getInt64(0));
}
}
return source_index;
}
llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
const {
PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type();
llvm::Type* param_ir_type =
llvm_ir::PrimitiveTypeToIrType(param_prim_type, ir_builder_);
// Same values as PCG library
// https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h
llvm::Value* multiplier = ir_builder_->getInt(
llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4}));
llvm::Value* increment = ir_builder_->getInt(
llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D}));
auto random_value = [hlo]() {
const HloModule* module =
hlo->IsFused() ? hlo->fusion_instruction()->parent()->parent()
: hlo->parent()->parent();
return module->RandomNew64();
};
// Seed each RNG emitter with a new 64-bit seed from the HloModule. If the
// compilation order is deterministic (i.e., the RandomNew64 invocation order
// is deterministic), then the sequence of RNG values is deterministic for a
// given seed and hence tests will be deterministic.
// If the user provides a global seed instruction, then only 64 bits of the
// 128-bit state are seeded from the host's random number generator; the other
// 64 bits come from the user-specified global seed instruction.
// Create a GlobalVariable to maintain state between invocations. There is a
// bug in NVPTX with GlobalVariable and 128 bit values, so using 2 64-bit
// values.
llvm::GlobalVariable* state_ptr0 = new llvm::GlobalVariable(
/*M=*/*module_,
/*Ty=*/ir_builder_->getInt64Ty(),
/*isConstant=*/false,
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
/*Initializer=*/ir_builder_->getInt64(random_value()),
/*Name=*/"state_ptr0");
uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed()
: random_value();
llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable(
/*M=*/*module_,
/*Ty=*/ir_builder_->getInt64Ty(),
/*isConstant=*/false,
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
/*Initializer=*/ir_builder_->getInt64(graph_seed),
/*Name=*/"state_ptr1");
// We want each thread to use its own stream, so we modify the increment per
// thread. We want the increment to remain odd, so we shift the thread id left
// 1 and add it to the increment.
increment = ir_builder_->CreateAdd(increment,
ir_builder_->CreateShl(EmitThreadId(), 1));
// PCG-XSL-RR algorithm
// http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf
// state = multiplier * state + increment
// return uint64_t(state ^ (state >> 64)) >>> (state >> 122)
// where ">>>" is bitwise rotation
auto get_next_i64 = [=]() {
llvm::Value* state0 = ir_builder_->CreateZExtOrTrunc(
ir_builder_->CreateLoad(state_ptr0, "state0"),
ir_builder_->getInt128Ty());
llvm::Value* state1 = ir_builder_->CreateShl(
ir_builder_->CreateZExtOrTrunc(
ir_builder_->CreateLoad(state_ptr1, "state1"),
ir_builder_->getInt128Ty()),
64);
llvm::Value* state = ir_builder_->CreateOr(state0, state1);
llvm::Value* updated = ir_builder_->CreateAdd(
ir_builder_->CreateMul(state, multiplier), increment);
ir_builder_->CreateStore(
ir_builder_->CreateTrunc(updated, ir_builder_->getInt64Ty()),
state_ptr0);
ir_builder_->CreateStore(
ir_builder_->CreateTrunc(ir_builder_->CreateLShr(updated, 64),
ir_builder_->getInt64Ty()),
state_ptr1);
return llvm_ir::CreateRor(
ir_builder_->CreateTrunc(
ir_builder_->CreateXor(state, ir_builder_->CreateLShr(state, 64)),
ir_builder_->getInt64Ty()),
ir_builder_->CreateTrunc(ir_builder_->CreateLShr(state, 122),
ir_builder_->getInt64Ty()),
ir_builder_);
};
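// get_next_uniform_float maps the next 64 random bits into [0, 1) by dividing
// by 2^64 (the constant 0x1p64 below).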
auto get_next_uniform_float = [=]() {
return ir_builder_->CreateFDiv(
ir_builder_->CreateUIToFP(get_next_i64(), param_ir_type),
llvm::ConstantFP::get(param_ir_type, 0x1p64));
};
return [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
switch (hlo->random_distribution()) {
case RNG_UNIFORM: {
TF_ASSIGN_OR_RETURN(llvm::Value * p,
operand_to_generator.at(hlo->operand(0))(index));
TF_ASSIGN_OR_RETURN(llvm::Value * q,
operand_to_generator.at(hlo->operand(1))(index));
if (primitive_util::IsFloatingPointType(param_prim_type)) {
return ir_builder_->CreateFAdd(
ir_builder_->CreateFMul(ir_builder_->CreateFSub(q, p),
get_next_uniform_float()),
p);
} else {
auto r = ir_builder_->CreateSub(q, p);
auto leading_zeros = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)},
{param_ir_type}, ir_builder_);
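// Rejection sampling: mask each draw down to the smallest power-of-two range
// covering r = q - p (using the count of leading zeros of r), and loop in
// rng_body until the draw is < r; out_block then returns p + draw (or p when
// p == q).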
auto in_block = ir_builder_->GetInsertBlock();
// A terminator should be present iff we're emitting code
// into the middle (as opposed to the end) of a basic block.
CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(),
in_block->getTerminator() == nullptr);
llvm::BasicBlock* body_block;
llvm::BasicBlock* out_block;
if (ir_builder_->GetInsertPoint() == in_block->end()) {
body_block =
llvm_ir::CreateBasicBlock(nullptr, "rng_body", ir_builder_);
out_block =
llvm_ir::CreateBasicBlock(nullptr, "rng_out", ir_builder_);
llvm::BranchInst::Create(body_block, in_block);
} else {
body_block = in_block->splitBasicBlock(
ir_builder_->GetInsertPoint(), "rng_body");
out_block = body_block->splitBasicBlock(
ir_builder_->GetInsertPoint(), "rng_out");
body_block->getTerminator()->eraseFromParent();
}
SetToFirstInsertPoint(body_block, ir_builder_);
auto random = ir_builder_->CreateAnd(
ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type),
ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0),
leading_zeros));
llvm::BranchInst::Create(out_block, body_block,
ir_builder_->CreateICmpULT(random, r),
body_block);
SetToFirstInsertPoint(out_block, ir_builder_);
return ir_builder_->CreateAdd(
p, ir_builder_->CreateSelect(
ir_builder_->CreateICmpEQ(p, q),
llvm::ConstantInt::get(param_ir_type, 0), random));
}
}
case RNG_NORMAL: {
TF_ASSIGN_OR_RETURN(llvm::Value * m,
operand_to_generator.at(hlo->operand(0))(index));
TF_ASSIGN_OR_RETURN(llvm::Value * s,
operand_to_generator.at(hlo->operand(1))(index));
TF_ASSIGN_OR_RETURN(
llvm::Value * r,
EmitErfcInv(param_prim_type,
ir_builder_->CreateFMul(
llvm::ConstantFP::get(param_ir_type, 2.0),
get_next_uniform_float())));
return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m);
}
case RNG_BERNOULLI: {
TF_ASSIGN_OR_RETURN(llvm::Value * p,
operand_to_generator.at(hlo->operand(0))(index));
return ir_builder_->CreateZExt(
ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p),
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
ir_builder_));
}
default:
return InvalidArgument(
"unhandled distribution %s",
RandomDistribution_Name(hlo->random_distribution()).c_str());
}
};
}
llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
const {
switch (hlo->opcode()) {
case HloOpcode::kAbs:
case HloOpcode::kCeil:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kFloor:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kNegate:
case HloOpcode::kSign:
case HloOpcode::kTanh:
case HloOpcode::kLogicalNot:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(0))(
ElementwiseSourceIndex(index, *hlo, 0)));
return EmitUnaryOp(hlo, operand_value);
};
case HloOpcode::kAdd:
case HloOpcode::kDivide:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
case HloOpcode::kLogicalAnd:
case HloOpcode::kLogicalOr:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
const HloInstruction* lhs = hlo->operand(0);
const HloInstruction* rhs = hlo->operand(1);
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
operand_to_generator.at(lhs)(
ElementwiseSourceIndex(index, *hlo, 0)));
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
operand_to_generator.at(rhs)(
ElementwiseSourceIndex(index, *hlo, 1)));
return EmitBinaryOp(hlo, lhs_value, rhs_value);
};
case HloOpcode::kSelect:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
operand_to_generator.at(hlo->operand(0))(
ElementwiseSourceIndex(index, *hlo, 0)));
TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
operand_to_generator.at(hlo->operand(1))(
ElementwiseSourceIndex(index, *hlo, 1)));
TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
operand_to_generator.at(hlo->operand(2))(
ElementwiseSourceIndex(index, *hlo, 2)));
return ir_builder_->CreateSelect(
ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()),
on_true_value, on_false_value);
};
case HloOpcode::kClamp:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
operand_to_generator.at(hlo->operand(0))(
ElementwiseSourceIndex(index, *hlo, 0)));
TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
operand_to_generator.at(hlo->operand(1))(
ElementwiseSourceIndex(index, *hlo, 1)));
TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
operand_to_generator.at(hlo->operand(2))(
ElementwiseSourceIndex(index, *hlo, 2)));
return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
};
case HloOpcode::kReducePrecision:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(0))(
ElementwiseSourceIndex(index, *hlo, 0)));
return EmitReducePrecision(hlo, operand_value);
};
case HloOpcode::kConcatenate:
return [this, hlo, &operand_to_generator](
const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
const int64 concat_dim = hlo->dimensions(0);
auto source_index = target_index;
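// The code below emits a chain of basic blocks, one per operand: each block
// tests whether the index along concat_dim falls inside that operand; if so
// it reads from that operand, otherwise it subtracts the operand's extent
// along concat_dim and falls through to the next block. The PHI node placed
// in exit_block collects the value from whichever operand was selected.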
llvm::PHINode* output = ir_builder_->CreatePHI(
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
ir_builder_),
hlo->operands().size());
llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock();
auto prior_insert_point = ir_builder_->GetInsertPoint();
llvm::BasicBlock* exit_block =
init_block->splitBasicBlock(output, "concat_merge");
ir_builder_->SetInsertPoint(init_block);
init_block->getTerminator()->eraseFromParent();
for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
++operand_idx) {
const HloInstruction* operand = hlo->operand(operand_idx);
auto true_block = llvm_ir::CreateBasicBlock(
exit_block, tensorflow::strings::StrCat(
"concat_index_from_operand", operand_idx),
ir_builder_);
auto false_block = llvm_ir::CreateBasicBlock(
exit_block, tensorflow::strings::StrCat(
"concat_index_not_from_operand", operand_idx),
ir_builder_);
auto concat_dim_size =
llvm::ConstantInt::get(source_index[concat_dim]->getType(),
operand->shape().dimensions(concat_dim));
ir_builder_->CreateCondBr(
ir_builder_->CreateICmpULT(source_index[concat_dim],
concat_dim_size),
true_block, false_block);
// Create the terminator of the true block before calling operand
// generators, because they require non-degenerate basic blocks.
ir_builder_->SetInsertPoint(
llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block));
TF_ASSIGN_OR_RETURN(llvm::Value * value,
operand_to_generator.at(operand)(source_index));
output->addIncoming(value, ir_builder_->GetInsertBlock());
// Subtract the size of the concat dimension of the current operand
// from the source index.
ir_builder_->SetInsertPoint(false_block);
source_index[concat_dim] =
ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size);
}
ir_builder_->CreateUnreachable();
ir_builder_->SetInsertPoint(exit_block, prior_insert_point);
return output;
};
case HloOpcode::kReverse:
return [this, hlo, &operand_to_generator](
const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
const HloInstruction* operand = hlo->operand(0);
auto source_index = target_index;
for (int64 dim : hlo->dimensions()) {
source_index[dim] = ir_builder_->CreateSub(
llvm::ConstantInt::get(target_index[dim]->getType(),
hlo->shape().dimensions(dim) - 1),
target_index[dim]);
}
return operand_to_generator.at(operand)(source_index);
};
case HloOpcode::kBroadcast:
return [this, hlo, &operand_to_generator](
const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
// The `dimensions` member of the broadcast instruction maps from
// input dimensions to output dimensions.
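// For example, broadcasting an f32[2] operand into an f32[3,2] result with
// dimensions={1}: target index {i, j} maps to source index {j}.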
const HloInstruction* operand = hlo->operand(0);
int64 rank = ShapeUtil::Rank(operand->shape());
IrArray::Index source_index(rank);
for (int64 i = 0; i < rank; ++i) {
source_index[i] = target_index[hlo->dimensions(i)];
}
return operand_to_generator.at(operand)(source_index);
};
case HloOpcode::kSlice:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
IrArray::Index sliced_index(index.size());
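// Map the output index back to the operand: for each dimension,
// sliced_index[i] = index[i] * slice_stride(i) + slice_starts(i) (the
// multiply is skipped when the stride is 1).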
for (int i = 0; i < index.size(); ++i) {
int64 stride = hlo->slice_stride(i);
if (stride != 1) {
sliced_index[i] = ir_builder_->CreateAdd(
ir_builder_->CreateMul(
index[i], llvm::ConstantInt::get(index[i]->getType(),
stride)),
llvm::ConstantInt::get(index[i]->getType(),
hlo->slice_starts(i)));
} else {
sliced_index[i] = ir_builder_->CreateAdd(
index[i],
llvm::ConstantInt::get(index[i]->getType(),
hlo->slice_starts(i)));
}
}
return operand_to_generator.at(hlo->operand(0))(sliced_index);
};
case HloOpcode::kDynamicSlice:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
// Emit IR to read dynamic start indices from hlo->operand(1).
const HloInstruction* input_hlo = hlo->operand(0);
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
llvm_ir::IrArray::Index slice_start_index(rank);
for (int64 i = 0; i < rank; ++i) {
llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
TF_ASSIGN_OR_RETURN(
llvm::Value * start_index_value,
operand_to_generator.at(hlo->operand(1))(dim_index));
slice_start_index[i] = start_index_value;
}
llvm_ir::IrArray::Index input_index(rank);
for (int64 i = 0; i < rank; ++i) {
// Emit IR which computes:
// input_index = (start_index + offset_index) % dim_size
// Security note: this is the code that keeps the indices in-bounds.
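// For example (illustrative values), with dim_size = 8, start_index = 6 and
// offset_index = 3: input_index = (6 + 3) % 8 = 1, so an out-of-range start
// wraps around instead of reading out of bounds.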
llvm::Value* dim_size = llvm::ConstantInt::get(
index[i]->getType(), input_hlo->shape().dimensions(i));
llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
slice_start_index[i], index[i]->getType());
input_index[i] = ir_builder_->CreateURem(
ir_builder_->CreateAdd(start_index, index[i]), dim_size);
}
return operand_to_generator.at(input_hlo)(input_index);
};
case HloOpcode::kDynamicUpdateSlice:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
const HloInstruction* input_hlo = hlo->operand(0);
const HloInstruction* update_hlo = hlo->operand(1);
const HloInstruction* start_hlo = hlo->operand(2);
// Calculate slice start/end indices.
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
llvm_ir::IrArray::Index slice_start_index(rank);
llvm_ir::IrArray::Index slice_limit_index(rank);
for (int64 i = 0; i < rank; ++i) {
// Emit IR to read dynamic start indices from 'start_hlo'.
llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
operand_to_generator.at(start_hlo)(dim_index));
slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
start_index_value, index[i]->getType());
// Emit IR to compute: slice_limit_index = start_index + update_dim
// NOTE: Although 'start_indices' is dynamic and could be
// out-of-range, we do not compute 'slice_limit_index' mod input dim
// size here, because subsequent array index calculations will be
// computed mod input dim size for safety.
llvm::Value* update_dim_size = llvm::ConstantInt::get(
index[i]->getType(), update_hlo->shape().dimensions(i));
slice_limit_index[i] =
ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
}
// Check if 'index' intersects start/end indices.
llvm::Value* slice_intersection =
llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1);
for (int64 i = 0; i < rank; ++i) {
// Check that index[i] >= slice_start_index[i].
slice_intersection = ir_builder_->CreateAnd(
slice_intersection,
ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
"slice_intersection");
// Check that index[i] < slice_limit_index[i].
slice_intersection = ir_builder_->CreateAnd(
slice_intersection,
ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
"slice_intersection");
}
// Emit:
// if (slice_intersection) -> return data from 'update'.
// else -> return data from 'input' at 'index'.
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
ir_builder_),
"ret_value_addr", ir_builder_);
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
slice_intersection, "slice_intersection", ir_builder_);
// Handle true BB.
SetToFirstInsertPoint(if_data.true_block, ir_builder_);
// Compute update index for intersection case.
llvm_ir::IrArray::Index update_index(rank);
for (int64 i = 0; i < rank; ++i) {
llvm::Value* update_dim_size = llvm::ConstantInt::get(
index[i]->getType(), update_hlo->shape().dimensions(i));
// NOTE: Subtraction will be positive due to bounds checking above.
update_index[i] = ir_builder_->CreateURem(
ir_builder_->CreateSub(index[i], slice_start_index[i]),
update_dim_size);
}
TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
operand_to_generator.at(update_hlo)(update_index));
ir_builder_->CreateStore(true_value, ret_value_addr);
// Handle false BB.
SetToFirstInsertPoint(if_data.false_block, ir_builder_);
TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
operand_to_generator.at(input_hlo)(index));
ir_builder_->CreateStore(false_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, ir_builder_);
return ir_builder_->CreateLoad(ret_value_addr);
};
case HloOpcode::kReshape:
CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
const HloInstruction* operand = hlo->operand(0);
return operand_to_generator.at(operand)(index.SourceIndexOfReshape(
hlo->shape(), operand->shape(), ir_builder_));
};
case HloOpcode::kTranspose:
return [this, hlo,
&operand_to_generator](const IrArray::Index& target_index) {
return operand_to_generator.at(hlo->operand(0))(
target_index.SourceIndexOfTranspose(
hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(),
ir_builder_));
};
case HloOpcode::kRng:
return MakeRngElementGenerator(hlo, operand_to_generator);
case HloOpcode::kPad:
return [=, &operand_to_generator](
const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
auto index = padded_index;
llvm::Value* in_bounds = ir_builder_->getTrue();
for (size_t i = 0; i < index.size(); ++i) {
auto index_typed_const = [=](int64 n) {
return llvm::ConstantInt::get(index[i]->getType(), n);
};
const auto& pad_dim = hlo->padding_config().dimensions(i);
index[i] = ir_builder_->CreateSub(
index[i], index_typed_const(pad_dim.edge_padding_low()));
in_bounds = ir_builder_->CreateAnd(
in_bounds,
ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
"in_bounds");
in_bounds = ir_builder_->CreateAnd(
in_bounds,
ir_builder_->CreateICmpEQ(
index_typed_const(0),
ir_builder_->CreateURem(
index[i],
index_typed_const(pad_dim.interior_padding() + 1))),
"in_bounds");
index[i] = ir_builder_->CreateSDiv(
index[i], index_typed_const(pad_dim.interior_padding() + 1));
in_bounds = ir_builder_->CreateAnd(
in_bounds,
ir_builder_->CreateICmpSLT(
index[i],
index_typed_const(hlo->operand(0)->shape().dimensions(i))),
"in_bounds");
}
// if (in_bounds) {
// ret_value = operand0[index]; // source
// } else {
// ret_value = *operand1; // padding
// }
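// As an illustration (hypothetical values), with edge_padding_low = 1 and
// interior_padding = 2 on a dimension of size 3: padded index 4 becomes
// 4 - 1 = 3, 3 % 3 == 0 and 3 / 3 = 1 < 3, so operand element 1 is read,
// while padded index 5 leaves remainder 1 and falls through to the padding
// value.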
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
ir_builder_),
"pad_result_addr", ir_builder_);
llvm_ir::LlvmIfData if_data =
llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
SetToFirstInsertPoint(if_data.true_block, ir_builder_);
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(0))(index));
ir_builder_->CreateStore(operand_value, ret_value_addr);
SetToFirstInsertPoint(if_data.false_block, ir_builder_);
TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
operand_to_generator.at(hlo->operand(1))({}));
ir_builder_->CreateStore(padding_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, ir_builder_);
// Don't create phi(operand_value, padding_value) here, because invoking
// operand_to_generator may create new basic blocks, making the parent
// of operand_value or padding_value no longer a predecessor of
// if_data.after_block.
return ir_builder_->CreateLoad(ret_value_addr);
};
default:
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
return Unimplemented("%s", HloOpcodeString(hlo->opcode()).c_str());
};
}
}
} // namespace xla