[XLA] More readable emitted LLVM code.

This commit is contained in:
Frederic Bastien 2021-02-01 06:28:45 -08:00
parent 164f5b2a5b
commit fb9882f089
14 changed files with 118 additions and 81 deletions

View File

@ -35,7 +35,8 @@ namespace cpu {
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
llvm::Value* rhs,
absl::string_view name) {
string function_name;
bool cast_result_to_fp16 = false;
switch (prim_type) {

View File

@ -37,7 +37,8 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
protected:
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) override;
llvm::Value* rhs,
absl::string_view name = "") override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) override;

View File

@ -828,15 +828,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
case HloOpcode::kComplex:
return EmitComposeComplex(op, lhs_value, rhs_value);
case HloOpcode::kAdd:
return FAdd(lhs_value, rhs_value);
return FAdd(lhs_value, rhs_value, op->name());
case HloOpcode::kSubtract:
return FSub(lhs_value, rhs_value);
return FSub(lhs_value, rhs_value, op->name());
case HloOpcode::kMultiply:
return FMul(lhs_value, rhs_value);
return FMul(lhs_value, rhs_value, op->name());
case HloOpcode::kDivide:
return FDiv(lhs_value, rhs_value);
return FDiv(lhs_value, rhs_value, op->name());
case HloOpcode::kRemainder:
return FRem(lhs_value, rhs_value);
return FRem(lhs_value, rhs_value, op->name());
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
// unordered comparisons return true.
@ -848,32 +848,32 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
switch (op->comparison_direction()) {
case ComparisonDirection::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kLt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kGt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kLe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kGe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
}
}
case HloOpcode::kMaximum:
return EmitFloatMax(lhs_value, rhs_value);
return EmitFloatMax(lhs_value, rhs_value, op->name());
case HloOpcode::kMinimum:
return EmitFloatMin(lhs_value, rhs_value);
return EmitFloatMin(lhs_value, rhs_value, op->name());
case HloOpcode::kPower:
return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
return EmitPow(op->shape().element_type(), lhs_value, rhs_value, op->name());
case HloOpcode::kAtan2:
return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value, op->name());
default:
return Unimplemented("binary floating point op '%s'",
HloOpcodeString(op->opcode()));
@ -1314,13 +1314,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
}
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max());
llvm::Value* rhs_value,
absl::string_view name) {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max(), name);
}
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max());
llvm::Value* rhs_value,
absl::string_view name) {
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
@ -1404,9 +1406,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
llvm::Value* value) {
llvm::Value* value,
absl::string_view name) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
{value->getType()}, b_);
{value->getType()}, b_, name);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
@ -1438,9 +1441,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
llvm::Value* rhs,
absl::string_view name) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
{lhs->getType()}, b_);
{lhs->getType()}, b_, name);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
@ -1458,7 +1462,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
llvm::Value* rhs,
absl::string_view name) {
return Unimplemented("atan2");
}

View File

@ -105,10 +105,12 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
llvm::Value* rhs_value);
virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
llvm::Value* rhs_value);
llvm::Value* rhs_value,
absl::string_view name = "");
virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value);
llvm::Value* rhs_value,
absl::string_view name = "");
llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed);
@ -117,7 +119,8 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
bool is_signed);
virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* lhs, llvm::Value* rhs,
absl::string_view name = "");
virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
llvm::Value* value);
@ -141,13 +144,15 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
llvm::Value* value);
llvm::Value* value,
absl::string_view name = "");
virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
llvm::Value* value);
virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* lhs, llvm::Value* rhs,
absl::string_view name = "");
virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value);

View File

@ -78,7 +78,8 @@ GpuElementalIrEmitter::GpuElementalIrEmitter(
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
absl::string_view name) {
// Device functions don't have f16 math functions, so we convert the operands
// to f32 before calling the function and then convert the result back to f16.
bool cast_result_to_fp16 = false;
@ -109,7 +110,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
const string& munged_callee =
ObtainDeviceFunctionName(funcid, output_type, b());
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
converted_input_types, output_type)
converted_input_types, output_type, name)
.ValueOrDie();
if (cast_result_to_fp16) {
result = FPCast(result, b()->getHalfTy());
@ -142,7 +143,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
const string& callee_name, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
absl::string_view name) {
// Binary math functions are of type [T] -> T.
for (PrimitiveType input_type : input_types) {
if (output_type != input_type) {
@ -154,7 +156,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
return EmitDeviceFunctionCall(
callee_name, operands, input_types, output_type,
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b());
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b(), name);
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
@ -222,7 +224,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type,
llvm::Value* value) {
llvm::Value* value,
absl::string_view name) {
return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type},
prim_type);
}
@ -235,9 +238,10 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
llvm::Value* rhs,
absl::string_view name) {
return EmitDeviceMathCall(TargetDeviceFunctionID::kPow, {lhs, rhs},
{prim_type, prim_type}, prim_type);
{prim_type, prim_type}, prim_type, name);
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
@ -254,9 +258,10 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
llvm::Value* rhs,
absl::string_view name) {
return EmitDeviceMathCall(TargetDeviceFunctionID::kAtan2, {lhs, rhs},
{prim_type, prim_type}, prim_type);
{prim_type, prim_type}, prim_type, name);
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,

View File

@ -65,7 +65,8 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
llvm::Value* value) override;
StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
llvm::Value* value) override;
llvm::Value* value,
absl::string_view name = "") override;
StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
llvm::Value* value) override;
@ -77,10 +78,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
llvm::Value* value) override;
StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) override;
llvm::Value* rhs, absl::string_view name = "") override;
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) override;
llvm::Value* rhs, absl::string_view name = "") override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) override;
@ -118,13 +119,15 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
// return value of the function.
StatusOr<llvm::Value*> EmitDeviceMathCall(
TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
absl::string_view name = "");
// Emits IR to call a function of type [T] -> T. Does not munge callee_name.
// Returns the IR value that represents the return value of the function.
StatusOr<llvm::Value*> EmitMathCall(
const string& callee_name, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
absl::string_view name = "");
const HloModuleConfig& hlo_module_config_;

View File

@ -91,7 +91,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : hlo->operands()) {
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_, operand->name());
};
}
return EmitTargetElementLoop(
@ -688,7 +688,8 @@ void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
fused_emitter->BindGenerator(
fusion->fused_parameter(i),
[this, operand, fusion](llvm_ir::IrArray::Index index) {
return GetIrArray(*operand, *fusion).EmitReadArrayElement(index, &b_);
return GetIrArray(*operand, *fusion).EmitReadArrayElement(
index, &b_, operand->name());
});
}
}

View File

@ -1900,10 +1900,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
GetNestedComputer());
FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter);
for (int i = 0; i < fused_computation->num_parameters(); i++) {
auto fused_operand = fused_computation->parameter_instruction(i);
operand_fused_emitter.BindGenerator(
fused_computation->parameter_instruction(i),
[this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
return ir_arrays[i].EmitReadArrayElement(index, &b_);
fused_operand,
[this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
return ir_arrays[i].EmitReadArrayElement(
index, &b_, fused_operand->name());
});
}
TF_ASSIGN_OR_RETURN(
@ -1942,10 +1944,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
GetNestedComputer());
FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter);
for (int i = 0; i < fused_computation->num_parameters(); i++) {
auto fused_operand = fused_computation->parameter_instruction(i);
scatter_fused_emitter.BindGenerator(
fused_computation->parameter_instruction(i),
[this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
return ir_arrays[i].EmitReadArrayElement(index, &b_);
fused_operand,
[this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
return ir_arrays[i].EmitReadArrayElement(
index, &b_, fused_operand->name());
});
}
@ -2049,10 +2053,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
/*is_fusion=*/true));
for (int i = 0; i < fused_computation->num_parameters(); i++) {
auto fused_operand = fused_computation->parameter_instruction(i);
fused_emitter.BindGenerator(
fused_computation->parameter_instruction(i),
[this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
return ir_arrays[i].EmitReadArrayElement(index, &b_);
fused_operand,
[this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
return ir_arrays[i].EmitReadArrayElement(
index, &b_, fused_operand->name());
});
}
@ -4165,8 +4171,10 @@ void IrEmitterUnnested::EmitTileElementForFusion(
};
} else {
auto array = operand_arrays[i];
gen = [this, array](llvm_ir::IrArray::Index index) {
return array.EmitReadArrayElement(index, &b_);
auto name = fused_computation->parameter_instruction(i)->name();
gen = [this, array, name](llvm_ir::IrArray::Index index) {
return array.EmitReadArrayElement(
index, &b_, name);
};
}
fused_emitter.BindGenerator(fused_computation->parameter_instruction(i),
@ -5621,10 +5629,12 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
CHECK_LT(fused_computation->num_parameters(), ir_arrays.size());
for (int i = 0; i < fused_computation->num_parameters(); i++) {
auto ir_array = ir_arrays[i];
auto fused_operand = fused_computation->parameter_instruction(i);
fused_emitter->BindGenerator(
fused_computation->parameter_instruction(i),
[this, ir_array](llvm_ir::IrArray::Index index) {
return ir_array.EmitReadArrayElement(index, &b_);
fused_operand,
[this, ir_array, fused_operand](llvm_ir::IrArray::Index index) {
return ir_array.EmitReadArrayElement(
index, &b_, fused_operand->name());
});
}
result_ir_arrays = absl::MakeSpan(ir_arrays).subspan(

View File

@ -194,7 +194,7 @@ llvm::CallInst* EmitDeviceFunctionCall(
const string& callee_name, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
absl::Span<const llvm::Attribute::AttrKind> attributes,
llvm::IRBuilder<>* b) {
llvm::IRBuilder<>* b, absl::string_view name) {
std::vector<llvm::Type*> ir_input_types;
llvm::Module* module = b->GetInsertBlock()->getModule();
for (PrimitiveType input_type : input_types) {
@ -217,7 +217,7 @@ llvm::CallInst* EmitDeviceFunctionCall(
callee->addFnAttr(attribute);
}
return b->CreateCall(callee, llvm_ir::AsArrayRef(operands));
return b->CreateCall(callee, llvm_ir::AsArrayRef(operands), name.data());
}
llvm::CallInst* EmitCallToTargetIntrinsic(

View File

@ -69,7 +69,7 @@ llvm::CallInst* EmitDeviceFunctionCall(
const std::string& callee_name, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_type, PrimitiveType output_type,
absl::Span<const llvm::Attribute::AttrKind> attributes,
llvm::IRBuilder<>* b);
llvm::IRBuilder<>* b, absl::string_view name = "");
// Emits a call to the specified target intrinsic with the given operands.
// Overloaded intrinsics (for example, "minnum") must include a type

View File

@ -102,7 +102,7 @@ Status FusedIrEmitter::HandleConstant(const HloInstruction* constant) {
global,
llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
return IrArray(shape_constant, constant->shape())
.EmitReadArrayElement(index, b_);
.EmitReadArrayElement(index, b_, constant->name());
};
return Status::OK();

View File

@ -504,7 +504,7 @@ llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
bool use_linear_index) const {
llvm::Value* element_address =
EmitArrayElementAddress(index, b, name, use_linear_index);
llvm::LoadInst* load = b->CreateLoad(element_address);
llvm::LoadInst* load = b->CreateLoad(element_address, name.data());
AnnotateLoadStoreInstructionWithMetadata(load);
return load;
}

View File

@ -83,36 +83,39 @@ string DumpModuleToString(const llvm::Module& module) {
llvm::CallInst* EmitCallToIntrinsic(
llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
absl::string_view name) {
llvm::Module* module = ModuleFromIRBuilder(b);
llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
module, intrinsic_id, AsArrayRef(overloaded_types));
return b->CreateCall(intrinsic, AsArrayRef(operands));
return b->CreateCall(intrinsic, AsArrayRef(operands), name.data());
}
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b, bool enable_fast_min_max) {
llvm::IRBuilder<>* b, bool enable_fast_min_max,
absl::string_view name) {
if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
return b->CreateSelect(cmp, lhs_value, rhs_value);
return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
} else {
auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
}
}
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b, bool enable_fast_min_max) {
llvm::IRBuilder<>* b, bool enable_fast_min_max,
absl::string_view name) {
if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
return b->CreateSelect(cmp, lhs_value, rhs_value);
return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
} else {
auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
}
}
@ -351,12 +354,12 @@ LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b) {
llvm::IRBuilder<>* b, absl::string_view name) {
llvm::Value* comparison_result;
if (lhs_value->getType()->isIntegerTy()) {
comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value);
comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value, name.data());
} else {
comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value);
comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value, name.data());
}
// comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
// arrays. So we extend it to i8 so that it's addressable.

View File

@ -103,17 +103,20 @@ string SanitizeFunctionName(string function_name);
// overloaded type.
llvm::CallInst* EmitCallToIntrinsic(
llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b);
absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
absl::string_view name = "");
// Emit float max. Emit maxnum intrinsic if fast math is disabled, or
// fcmp+select otherwise
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b, bool enable_fast_min_max);
llvm::IRBuilder<>* b, bool enable_fast_min_max,
absl::string_view name = "");
// Emit float min. Emit minnum intrinsic if fast math is disabled, or
// fcmp+select otherwise
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b, bool enable_fast_min_max);
llvm::IRBuilder<>* b, bool enable_fast_min_max,
absl::string_view name = "");
// Convenience methods for emitting a GEP instruction that indexes into a buffer
// (1-dimensional array), equivalent to array[index]. The type is automatically
@ -214,7 +217,7 @@ LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
// and then converts the result to i8 so that it is addressable.
llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
llvm::Value* lhs, llvm::Value* rhs,
llvm::IRBuilder<>* b);
llvm::IRBuilder<>* b, absl::string_view name = "");
// Emits a call that logs the given value with the given tag as a prefix.
// The provided tag and value are passed to a runtime logging call that is