[XLA] More readable emitted LLVM code.
This commit is contained in:
parent
164f5b2a5b
commit
fb9882f089
@ -35,7 +35,8 @@ namespace cpu {
|
||||
|
||||
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
|
||||
llvm::Value* lhs,
|
||||
llvm::Value* rhs) {
|
||||
llvm::Value* rhs,
|
||||
absl::string_view name) {
|
||||
string function_name;
|
||||
bool cast_result_to_fp16 = false;
|
||||
switch (prim_type) {
|
||||
|
||||
@ -37,7 +37,8 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
|
||||
|
||||
protected:
|
||||
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
|
||||
llvm::Value* rhs) override;
|
||||
llvm::Value* rhs,
|
||||
absl::string_view name = "") override;
|
||||
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
|
||||
llvm::Value* value) override;
|
||||
|
||||
|
||||
@ -828,15 +828,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
|
||||
case HloOpcode::kComplex:
|
||||
return EmitComposeComplex(op, lhs_value, rhs_value);
|
||||
case HloOpcode::kAdd:
|
||||
return FAdd(lhs_value, rhs_value);
|
||||
return FAdd(lhs_value, rhs_value, op->name());
|
||||
case HloOpcode::kSubtract:
|
||||
return FSub(lhs_value, rhs_value);
|
||||
return FSub(lhs_value, rhs_value, op->name());
|
||||
case HloOpcode::kMultiply:
|
||||
return FMul(lhs_value, rhs_value);
|
||||
return FMul(lhs_value, rhs_value, op->name());
|
||||
case HloOpcode::kDivide:
|
||||
return FDiv(lhs_value, rhs_value);
|
||||
return FDiv(lhs_value, rhs_value, op->name());
|
||||
case HloOpcode::kRemainder:
|
||||
return FRem(lhs_value, rhs_value);
|
||||
return FRem(lhs_value, rhs_value, op->name());
|
||||
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
|
||||
// comparisons always return false when one of the operands is NaN, whereas
|
||||
// unordered comparisons return true.
|
||||
@ -848,32 +848,32 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
|
||||
switch (op->comparison_direction()) {
|
||||
case ComparisonDirection::kEq:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
|
||||
rhs_value, b_);
|
||||
rhs_value, b_, op->name());
|
||||
case ComparisonDirection::kNe:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
|
||||
rhs_value, b_);
|
||||
rhs_value, b_, op->name());
|
||||
case ComparisonDirection::kLt:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
|
||||
rhs_value, b_);
|
||||
rhs_value, b_, op->name());
|
||||
case ComparisonDirection::kGt:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
|
||||
rhs_value, b_);
|
||||
rhs_value, b_, op->name());
|
||||
case ComparisonDirection::kLe:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
|
||||
rhs_value, b_);
|
||||
rhs_value, b_, op->name());
|
||||
case ComparisonDirection::kGe:
|
||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
|
||||
rhs_value, b_);
|
||||
rhs_value, b_, op->name());
|
||||
}
|
||||
}
|
||||
case HloOpcode::kMaximum:
|
||||
return EmitFloatMax(lhs_value, rhs_value);
|
||||
return EmitFloatMax(lhs_value, rhs_value, op->name());
|
||||
case HloOpcode::kMinimum:
|
||||
return EmitFloatMin(lhs_value, rhs_value);
|
||||
return EmitFloatMin(lhs_value, rhs_value, op->name());
|
||||
case HloOpcode::kPower:
|
||||
return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
|
||||
return EmitPow(op->shape().element_type(), lhs_value, rhs_value, op->name());
|
||||
case HloOpcode::kAtan2:
|
||||
return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
|
||||
return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value, op->name());
|
||||
default:
|
||||
return Unimplemented("binary floating point op '%s'",
|
||||
HloOpcodeString(op->opcode()));
|
||||
@ -1314,13 +1314,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
|
||||
}
|
||||
|
||||
// Emits the floating-point maximum of `lhs_value` and `rhs_value` by
// delegating to llvm_ir::EmitFloatMax with this emitter's builder and its
// fast_min_max() setting. `name` is forwarded unchanged to that helper.
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value,
                                              absl::string_view name) {
  llvm::Value* const maximum = llvm_ir::EmitFloatMax(
      lhs_value, rhs_value, b_, /*enable_fast_min_max=*/fast_min_max(), name);
  return maximum;
}
|
||||
|
||||
// Emits the floating-point minimum of `lhs_value` and `rhs_value` by
// delegating to llvm_ir::EmitFloatMin with this emitter's builder and its
// fast_min_max() setting. `name` is forwarded unchanged to that helper.
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
                                              llvm::Value* rhs_value,
                                              absl::string_view name) {
  llvm::Value* const minimum = llvm_ir::EmitFloatMin(
      lhs_value, rhs_value, b_, /*enable_fast_min_max=*/fast_min_max(), name);
  return minimum;
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
|
||||
@ -1404,9 +1406,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
|
||||
}
|
||||
|
||||
// Emits exp(value) as a call to the llvm.exp intrinsic, overloaded on the
// operand's own LLVM type. `prim_type` is not consulted here; `name` is handed
// to EmitCallToIntrinsic.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
                                                   llvm::Value* value,
                                                   absl::string_view name) {
  llvm::Type* const overload_type = value->getType();
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp,
                                      /*operands=*/{value},
                                      /*overloaded_types=*/{overload_type}, b_,
                                      name);
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
|
||||
@ -1438,9 +1441,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
|
||||
|
||||
// Emits pow(lhs, rhs) as a call to the llvm.pow intrinsic, overloaded on the
// LLVM type of `lhs`. `prim_type` is not consulted here; `name` is handed to
// EmitCallToIntrinsic.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                   llvm::Value* lhs,
                                                   llvm::Value* rhs,
                                                   absl::string_view name) {
  llvm::Type* const overload_type = lhs->getType();
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow,
                                      /*operands=*/{lhs, rhs},
                                      /*overloaded_types=*/{overload_type}, b_,
                                      name);
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
|
||||
@ -1458,7 +1462,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
|
||||
|
||||
// Base-class atan2 lowering: none is provided, so this always yields an
// Unimplemented error and ignores every argument. Subclasses that support
// atan2 override this method.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
                                                     llvm::Value* lhs,
                                                     llvm::Value* rhs,
                                                     absl::string_view name) {
  auto not_supported = Unimplemented("atan2");
  return not_supported;
}
|
||||
|
||||
|
||||
@ -105,10 +105,12 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
llvm::Value* rhs_value);
|
||||
|
||||
virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value);
|
||||
llvm::Value* rhs_value,
|
||||
absl::string_view name = "");
|
||||
|
||||
virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value);
|
||||
llvm::Value* rhs_value,
|
||||
absl::string_view name = "");
|
||||
|
||||
llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
|
||||
bool is_signed);
|
||||
@ -117,7 +119,8 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
bool is_signed);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
|
||||
llvm::Value* lhs, llvm::Value* rhs);
|
||||
llvm::Value* lhs, llvm::Value* rhs,
|
||||
absl::string_view name = "");
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
|
||||
llvm::Value* value);
|
||||
@ -141,13 +144,15 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
llvm::Value* value);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
|
||||
llvm::Value* value);
|
||||
llvm::Value* value,
|
||||
absl::string_view name = "");
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
|
||||
llvm::Value* value);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
|
||||
llvm::Value* lhs, llvm::Value* rhs);
|
||||
llvm::Value* lhs, llvm::Value* rhs,
|
||||
absl::string_view name = "");
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
|
||||
llvm::Value* value);
|
||||
|
||||
@ -78,7 +78,8 @@ GpuElementalIrEmitter::GpuElementalIrEmitter(
|
||||
|
||||
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
|
||||
TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
|
||||
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
|
||||
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
|
||||
absl::string_view name) {
|
||||
// Device functions don't have f16 math functions, so we convert the operands
|
||||
// to f32 before calling the function and then convert the result back to f16.
|
||||
bool cast_result_to_fp16 = false;
|
||||
@ -109,7 +110,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
|
||||
const string& munged_callee =
|
||||
ObtainDeviceFunctionName(funcid, output_type, b());
|
||||
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
|
||||
converted_input_types, output_type)
|
||||
converted_input_types, output_type, name)
|
||||
.ValueOrDie();
|
||||
if (cast_result_to_fp16) {
|
||||
result = FPCast(result, b()->getHalfTy());
|
||||
@ -142,7 +143,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
|
||||
|
||||
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
|
||||
const string& callee_name, absl::Span<llvm::Value* const> operands,
|
||||
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
|
||||
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
|
||||
absl::string_view name) {
|
||||
// Binary math functions are of type [T] -> T.
|
||||
for (PrimitiveType input_type : input_types) {
|
||||
if (output_type != input_type) {
|
||||
@ -154,7 +156,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
|
||||
|
||||
return EmitDeviceFunctionCall(
|
||||
callee_name, operands, input_types, output_type,
|
||||
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b());
|
||||
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b(), name);
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
|
||||
@ -222,7 +224,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
|
||||
}
|
||||
|
||||
// Lowers exp(value) to a device math call (TargetDeviceFunctionID::kExp) whose
// single operand and result both use `prim_type`.
//
// `name` is forwarded to EmitDeviceMathCall, matching EmitPow and EmitAtan2.
// Previously it was accepted but dropped, so the requested name never reached
// the emitted call.
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type,
                                                      llvm::Value* value,
                                                      absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type},
                            prim_type, name);
}
|
||||
@ -235,9 +238,10 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
|
||||
|
||||
// Lowers pow(lhs, rhs) to a device math call: both operands and the result use
// `prim_type`, and `name` is passed through to EmitDeviceMathCall.
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                      llvm::Value* lhs,
                                                      llvm::Value* rhs,
                                                      absl::string_view name) {
  constexpr TargetDeviceFunctionID kCallee = TargetDeviceFunctionID::kPow;
  return EmitDeviceMathCall(kCallee, /*operands=*/{lhs, rhs},
                            /*input_types=*/{prim_type, prim_type},
                            /*output_type=*/prim_type, name);
}
|
||||
|
||||
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
|
||||
@ -254,9 +258,10 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
|
||||
|
||||
// Lowers atan2(lhs, rhs) to a device math call: both operands and the result
// use `prim_type`, and `name` is passed through to EmitDeviceMathCall.
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
                                                        llvm::Value* lhs,
                                                        llvm::Value* rhs,
                                                        absl::string_view name) {
  constexpr TargetDeviceFunctionID kCallee = TargetDeviceFunctionID::kAtan2;
  return EmitDeviceMathCall(kCallee, /*operands=*/{lhs, rhs},
                            /*input_types=*/{prim_type, prim_type},
                            /*output_type=*/prim_type, name);
}
|
||||
|
||||
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
||||
|
||||
@ -65,7 +65,8 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
|
||||
llvm::Value* value) override;
|
||||
|
||||
StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
|
||||
llvm::Value* value) override;
|
||||
llvm::Value* value,
|
||||
absl::string_view name = "") override;
|
||||
|
||||
StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
|
||||
llvm::Value* value) override;
|
||||
@ -77,10 +78,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
|
||||
llvm::Value* value) override;
|
||||
|
||||
StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs,
|
||||
llvm::Value* rhs) override;
|
||||
llvm::Value* rhs, absl::string_view name = "") override;
|
||||
|
||||
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
|
||||
llvm::Value* rhs) override;
|
||||
llvm::Value* rhs, absl::string_view name = "") override;
|
||||
|
||||
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
|
||||
llvm::Value* value) override;
|
||||
@ -118,13 +119,15 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
|
||||
// return value of the function.
|
||||
StatusOr<llvm::Value*> EmitDeviceMathCall(
|
||||
TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
|
||||
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
|
||||
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
|
||||
absl::string_view name = "");
|
||||
|
||||
// Emits IR to call a function of type [T] -> T. Does not munge callee_name.
|
||||
// Returns the IR value that represents the return value of the function.
|
||||
StatusOr<llvm::Value*> EmitMathCall(
|
||||
const string& callee_name, absl::Span<llvm::Value* const> operands,
|
||||
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
|
||||
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
|
||||
absl::string_view name = "");
|
||||
|
||||
const HloModuleConfig& hlo_module_config_;
|
||||
|
||||
|
||||
@ -91,7 +91,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
|
||||
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
|
||||
for (const HloInstruction* operand : hlo->operands()) {
|
||||
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
|
||||
return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
|
||||
return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_, operand->name());
|
||||
};
|
||||
}
|
||||
return EmitTargetElementLoop(
|
||||
@ -688,7 +688,8 @@ void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
|
||||
fused_emitter->BindGenerator(
|
||||
fusion->fused_parameter(i),
|
||||
[this, operand, fusion](llvm_ir::IrArray::Index index) {
|
||||
return GetIrArray(*operand, *fusion).EmitReadArrayElement(index, &b_);
|
||||
return GetIrArray(*operand, *fusion).EmitReadArrayElement(
|
||||
index, &b_, operand->name());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1900,10 +1900,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
GetNestedComputer());
|
||||
FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter);
|
||||
for (int i = 0; i < fused_computation->num_parameters(); i++) {
|
||||
auto fused_operand = fused_computation->parameter_instruction(i);
|
||||
operand_fused_emitter.BindGenerator(
|
||||
fused_computation->parameter_instruction(i),
|
||||
[this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
|
||||
return ir_arrays[i].EmitReadArrayElement(index, &b_);
|
||||
fused_operand,
|
||||
[this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
|
||||
return ir_arrays[i].EmitReadArrayElement(
|
||||
index, &b_, fused_operand->name());
|
||||
});
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -1942,10 +1944,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
GetNestedComputer());
|
||||
FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter);
|
||||
for (int i = 0; i < fused_computation->num_parameters(); i++) {
|
||||
auto fused_operand = fused_computation->parameter_instruction(i);
|
||||
scatter_fused_emitter.BindGenerator(
|
||||
fused_computation->parameter_instruction(i),
|
||||
[this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
|
||||
return ir_arrays[i].EmitReadArrayElement(index, &b_);
|
||||
fused_operand,
|
||||
[this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
|
||||
return ir_arrays[i].EmitReadArrayElement(
|
||||
index, &b_, fused_operand->name());
|
||||
});
|
||||
}
|
||||
|
||||
@ -2049,10 +2053,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
/*is_fusion=*/true));
|
||||
|
||||
for (int i = 0; i < fused_computation->num_parameters(); i++) {
|
||||
auto fused_operand = fused_computation->parameter_instruction(i);
|
||||
fused_emitter.BindGenerator(
|
||||
fused_computation->parameter_instruction(i),
|
||||
[this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
|
||||
return ir_arrays[i].EmitReadArrayElement(index, &b_);
|
||||
fused_operand,
|
||||
[this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
|
||||
return ir_arrays[i].EmitReadArrayElement(
|
||||
index, &b_, fused_operand->name());
|
||||
});
|
||||
}
|
||||
|
||||
@ -4165,8 +4171,10 @@ void IrEmitterUnnested::EmitTileElementForFusion(
|
||||
};
|
||||
} else {
|
||||
auto array = operand_arrays[i];
|
||||
gen = [this, array](llvm_ir::IrArray::Index index) {
|
||||
return array.EmitReadArrayElement(index, &b_);
|
||||
auto name = fused_computation->parameter_instruction(i)->name();
|
||||
gen = [this, array, name](llvm_ir::IrArray::Index index) {
|
||||
return array.EmitReadArrayElement(
|
||||
index, &b_, name);
|
||||
};
|
||||
}
|
||||
fused_emitter.BindGenerator(fused_computation->parameter_instruction(i),
|
||||
@ -5621,10 +5629,12 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
||||
CHECK_LT(fused_computation->num_parameters(), ir_arrays.size());
|
||||
for (int i = 0; i < fused_computation->num_parameters(); i++) {
|
||||
auto ir_array = ir_arrays[i];
|
||||
auto fused_operand = fused_computation->parameter_instruction(i);
|
||||
fused_emitter->BindGenerator(
|
||||
fused_computation->parameter_instruction(i),
|
||||
[this, ir_array](llvm_ir::IrArray::Index index) {
|
||||
return ir_array.EmitReadArrayElement(index, &b_);
|
||||
fused_operand,
|
||||
[this, ir_array, fused_operand](llvm_ir::IrArray::Index index) {
|
||||
return ir_array.EmitReadArrayElement(
|
||||
index, &b_, fused_operand->name());
|
||||
});
|
||||
}
|
||||
result_ir_arrays = absl::MakeSpan(ir_arrays).subspan(
|
||||
|
||||
@ -194,7 +194,7 @@ llvm::CallInst* EmitDeviceFunctionCall(
|
||||
const string& callee_name, absl::Span<llvm::Value* const> operands,
|
||||
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
|
||||
absl::Span<const llvm::Attribute::AttrKind> attributes,
|
||||
llvm::IRBuilder<>* b) {
|
||||
llvm::IRBuilder<>* b, absl::string_view name) {
|
||||
std::vector<llvm::Type*> ir_input_types;
|
||||
llvm::Module* module = b->GetInsertBlock()->getModule();
|
||||
for (PrimitiveType input_type : input_types) {
|
||||
@ -217,7 +217,7 @@ llvm::CallInst* EmitDeviceFunctionCall(
|
||||
callee->addFnAttr(attribute);
|
||||
}
|
||||
|
||||
return b->CreateCall(callee, llvm_ir::AsArrayRef(operands));
|
||||
return b->CreateCall(callee, llvm_ir::AsArrayRef(operands), name.data());
|
||||
}
|
||||
|
||||
llvm::CallInst* EmitCallToTargetIntrinsic(
|
||||
|
||||
@ -69,7 +69,7 @@ llvm::CallInst* EmitDeviceFunctionCall(
|
||||
const std::string& callee_name, absl::Span<llvm::Value* const> operands,
|
||||
absl::Span<const PrimitiveType> input_type, PrimitiveType output_type,
|
||||
absl::Span<const llvm::Attribute::AttrKind> attributes,
|
||||
llvm::IRBuilder<>* b);
|
||||
llvm::IRBuilder<>* b, absl::string_view name = "");
|
||||
|
||||
// Emits a call to the specified target intrinsic with the given operands.
|
||||
// Overloaded intrinsics (for example, "minnum") must include a type
|
||||
|
||||
@ -102,7 +102,7 @@ Status FusedIrEmitter::HandleConstant(const HloInstruction* constant) {
|
||||
global,
|
||||
llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
|
||||
return IrArray(shape_constant, constant->shape())
|
||||
.EmitReadArrayElement(index, b_);
|
||||
.EmitReadArrayElement(index, b_, constant->name());
|
||||
};
|
||||
|
||||
return Status::OK();
|
||||
|
||||
@ -504,7 +504,7 @@ llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
|
||||
bool use_linear_index) const {
|
||||
llvm::Value* element_address =
|
||||
EmitArrayElementAddress(index, b, name, use_linear_index);
|
||||
llvm::LoadInst* load = b->CreateLoad(element_address);
|
||||
llvm::LoadInst* load = b->CreateLoad(element_address, name.data());
|
||||
AnnotateLoadStoreInstructionWithMetadata(load);
|
||||
return load;
|
||||
}
|
||||
|
||||
@ -83,36 +83,39 @@ string DumpModuleToString(const llvm::Module& module) {
|
||||
|
||||
// Emits a call to the LLVM intrinsic `intrinsic_id` with `operands`, resolving
// the declaration in the module that `b` is inserting into.
// `overloaded_types` selects the concrete overload; `name` is used as the name
// of the resulting call instruction.
llvm::CallInst* EmitCallToIntrinsic(
    llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
    absl::string_view name) {
  llvm::Module* module = ModuleFromIRBuilder(b);
  llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
      module, intrinsic_id, AsArrayRef(overloaded_types));
  // A string_view is not guaranteed to be NUL-terminated, so hand LLVM an
  // explicit (pointer, length) pair instead of a bare `name.data()` C string.
  const llvm::StringRef call_name(name.data(), name.size());
  return b->CreateCall(intrinsic, AsArrayRef(operands), call_name);
}
|
||||
|
||||
// Emits max(lhs_value, rhs_value) as a compare followed by a select.
//
// When the builder's fast-math flags assume no NaNs, or `enable_fast_min_max`
// is set, a single unordered >= compare drives the select. Otherwise an
// ordered >= compare is OR'ed with an "lhs is NaN" test (lhs != lhs), so lhs
// is also chosen when it is NaN. `name` is given to the final select only.
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
                          absl::string_view name) {
  // A string_view is not guaranteed to be NUL-terminated, so hand LLVM an
  // explicit (pointer, length) pair instead of a bare `name.data()` C string.
  const llvm::StringRef select_name(name.data(), name.size());
  if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
    auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value, select_name);
  } else {
    auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, select_name);
  }
}
|
||||
|
||||
// Emits min(lhs_value, rhs_value) as a compare followed by a select.
//
// When the builder's fast-math flags assume no NaNs, or `enable_fast_min_max`
// is set, a single unordered <= compare drives the select. Otherwise an
// ordered <= compare is OR'ed with an "lhs is NaN" test (lhs != lhs), so lhs
// is also chosen when it is NaN. `name` is given to the final select only.
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
                          absl::string_view name) {
  // A string_view is not guaranteed to be NUL-terminated, so hand LLVM an
  // explicit (pointer, length) pair instead of a bare `name.data()` C string.
  const llvm::StringRef select_name(name.data(), name.size());
  if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
    auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
    return b->CreateSelect(cmp, lhs_value, rhs_value, select_name);
  } else {
    auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
    auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, select_name);
  }
}
|
||||
|
||||
@ -351,12 +354,12 @@ LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
|
||||
|
||||
llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
|
||||
llvm::Value* lhs_value, llvm::Value* rhs_value,
|
||||
llvm::IRBuilder<>* b) {
|
||||
llvm::IRBuilder<>* b, absl::string_view name) {
|
||||
llvm::Value* comparison_result;
|
||||
if (lhs_value->getType()->isIntegerTy()) {
|
||||
comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value);
|
||||
comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value, name.data());
|
||||
} else {
|
||||
comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value);
|
||||
comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value, name.data());
|
||||
}
|
||||
// comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
|
||||
// arrays. So we extend it to i8 so that it's addressable.
|
||||
|
||||
@ -103,17 +103,20 @@ string SanitizeFunctionName(string function_name);
|
||||
// overloaded type.
|
||||
llvm::CallInst* EmitCallToIntrinsic(
|
||||
llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
|
||||
absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b);
|
||||
absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
|
||||
absl::string_view name = "");
|
||||
|
||||
// Emit float max. Emit maxnum intrinsic if fast math is disabled, or
|
||||
// fcmp+select otherwise
|
||||
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
|
||||
llvm::IRBuilder<>* b, bool enable_fast_min_max);
|
||||
llvm::IRBuilder<>* b, bool enable_fast_min_max,
|
||||
absl::string_view name = "");
|
||||
|
||||
// Emit float min. Emit minnum intrinsic if fast math is disabled, or
|
||||
// fcmp+select otherwise
|
||||
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
|
||||
llvm::IRBuilder<>* b, bool enable_fast_min_max);
|
||||
llvm::IRBuilder<>* b, bool enable_fast_min_max,
|
||||
absl::string_view name = "");
|
||||
|
||||
// Convenience methods for emitting a GEP instruction that indexes into a buffer
|
||||
// (1-dimensional array), equivalent to array[index]. The type is automatically
|
||||
@ -214,7 +217,7 @@ LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
|
||||
// and then converts the result to i8 so that it is addressable.
|
||||
llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
|
||||
llvm::Value* lhs, llvm::Value* rhs,
|
||||
llvm::IRBuilder<>* b);
|
||||
llvm::IRBuilder<>* b, absl::string_view name = "");
|
||||
|
||||
// Emits a call that logs the given value with the given tag as a prefix.
|
||||
// The provided tag and value are passed to a runtime logging call that is
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user