In fast-math mode emit a tanh that has a faster min/max.
PiperOrigin-RevId: 164943597
This commit is contained in:
parent
87605f3d6a
commit
c0f9b0a91e
@ -100,7 +100,7 @@ operator()(llvm::Module& module) const {
|
||||
|
||||
CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
|
||||
|
||||
runtime::RewriteIRRuntimeFunctions(&module);
|
||||
runtime::RewriteIRRuntimeFunctions(&module, enable_fast_math_);
|
||||
|
||||
// Buffer for holding machine code prior to constructing the ObjectFile.
|
||||
llvm::SmallVector<char, 0> stream_buffer;
|
||||
|
@ -42,7 +42,7 @@ class CompilerFunctor {
|
||||
|
||||
explicit CompilerFunctor(
|
||||
llvm::TargetMachine* target_machine, const Disassembler* disassembler,
|
||||
int opt_level, bool optimize_for_size,
|
||||
int opt_level, bool optimize_for_size, bool enable_fast_math,
|
||||
const VectorIntrinsics& available_intrinsics,
|
||||
LLVMCompiler::ModuleHook pre_optimization_hook = nullptr,
|
||||
LLVMCompiler::ModuleHook post_optimization_hook = nullptr)
|
||||
@ -50,6 +50,7 @@ class CompilerFunctor {
|
||||
disassembler_(CHECK_NOTNULL(disassembler)),
|
||||
opt_level_(opt_level),
|
||||
optimize_for_size_(optimize_for_size),
|
||||
enable_fast_math_(enable_fast_math),
|
||||
available_intrinsics_(available_intrinsics),
|
||||
pre_optimization_hook_(pre_optimization_hook),
|
||||
post_optimization_hook_(post_optimization_hook) {}
|
||||
@ -72,6 +73,7 @@ class CompilerFunctor {
|
||||
const Disassembler* disassembler_;
|
||||
const unsigned opt_level_;
|
||||
const bool optimize_for_size_;
|
||||
const bool enable_fast_math_;
|
||||
const VectorIntrinsics available_intrinsics_;
|
||||
LLVMCompiler::ModuleHook pre_optimization_hook_;
|
||||
LLVMCompiler::ModuleHook post_optimization_hook_;
|
||||
|
@ -442,6 +442,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
|
||||
CompilerTargetOptions(module->config()),
|
||||
CodeGenOptLevel(module->config()),
|
||||
options::OptimizeForSizeRequested(module->config()),
|
||||
module->config().debug_options().xla_enable_fast_math(),
|
||||
pre_optimization_ir_hook, post_optimization_ir_hook);
|
||||
llvm_module->setDataLayout(jit->data_layout());
|
||||
llvm_module->setTargetTriple(jit->target_triple().getTriple());
|
||||
@ -794,6 +795,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
|
||||
CompilerFunctor compiler_functor(
|
||||
target_machine.get(), &disassembler, opt_level,
|
||||
options::OptimizeForSizeRequested(module->config()),
|
||||
module->config().debug_options().xla_enable_fast_math(),
|
||||
CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook,
|
||||
post_optimization_ir_dump_hook);
|
||||
llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =
|
||||
|
@ -30,9 +30,33 @@ const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32";
|
||||
const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
|
||||
|
||||
namespace {
|
||||
llvm::Value* EmitFMinOrMax(llvm::IRBuilder<>* ir_builder, llvm::Module* module,
|
||||
llvm::Type* vector_type, llvm::Value* lhs,
|
||||
llvm::Value* rhs, bool is_min,
|
||||
bool enable_fast_math) {
|
||||
if (enable_fast_math) {
|
||||
// Using an unordered comparison lets LLVM generate a vminps / vmaxps
|
||||
// instruction on x86. vminps/vmaxps choose the second operand if either
|
||||
// operand is a NaN and thus don't accurately implement the semantics of the
|
||||
// minnum and maxnum intrinsics, necessitating different IR emission.
|
||||
//
|
||||
// We can _probably_ do this even when fast math is disabled, but we can
|
||||
// certainly do this if fast math is enabled (and nnan applies).
|
||||
auto* compare = ir_builder->CreateFCmp(
|
||||
is_min ? llvm::FCmpInst::FCMP_ULE : llvm::FCmpInst::FCMP_UGE, lhs, rhs);
|
||||
return ir_builder->CreateSelect(compare, lhs, rhs);
|
||||
} else {
|
||||
llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
|
||||
module, is_min ? llvm::Intrinsic::minnum : llvm::Intrinsic::maxnum,
|
||||
vector_type);
|
||||
return ir_builder->CreateCall(intrinsic, {lhs, rhs});
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
||||
llvm::StringRef function_name,
|
||||
int vector_width) {
|
||||
int vector_width,
|
||||
bool enable_fast_math) {
|
||||
llvm::Function* vector_tanh_function = module->getFunction(function_name);
|
||||
if (vector_tanh_function == nullptr) {
|
||||
// If the function declaration is not present in the module, there can't be
|
||||
@ -45,11 +69,6 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
||||
llvm::VectorType* vector_type =
|
||||
llvm::VectorType::get(float_type, vector_width);
|
||||
|
||||
llvm::Function* min_intrinsic = llvm::Intrinsic::getDeclaration(
|
||||
module, llvm::Intrinsic::minnum, vector_type);
|
||||
llvm::Function* max_intrinsic = llvm::Intrinsic::getDeclaration(
|
||||
module, llvm::Intrinsic::maxnum, vector_type);
|
||||
|
||||
llvm::BasicBlock* vector_tanh_body =
|
||||
llvm::BasicBlock::Create(*context, "body", vector_tanh_function);
|
||||
|
||||
@ -59,15 +78,24 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
||||
fast_math_flags.setUnsafeAlgebra();
|
||||
ir_builder.setFastMathFlags(fast_math_flags);
|
||||
|
||||
auto emit_fmin = [&](llvm::Value* lhs, llvm::Value* rhs) {
|
||||
return EmitFMinOrMax(&ir_builder, module, vector_type, lhs, rhs,
|
||||
/*is_min=*/true,
|
||||
/*enable_fast_math=*/enable_fast_math);
|
||||
};
|
||||
auto emit_fmax = [&](llvm::Value* lhs, llvm::Value* rhs) {
|
||||
return EmitFMinOrMax(&ir_builder, module, vector_type, lhs, rhs,
|
||||
/*is_min=*/false,
|
||||
/*enable_fast_math=*/enable_fast_math);
|
||||
};
|
||||
|
||||
llvm::Value* input = &*vector_tanh_function->arg_begin();
|
||||
CHECK_EQ(input->getType(), vector_type);
|
||||
|
||||
// This implements the same rational interpolant as implemented in Eigen3.
|
||||
llvm::Value* input_clamped = ir_builder.CreateCall(
|
||||
min_intrinsic,
|
||||
{ir_builder.CreateCall(max_intrinsic,
|
||||
{input, llvm::ConstantFP::get(vector_type, -9.0)}),
|
||||
llvm::ConstantFP::get(vector_type, 9.0)});
|
||||
llvm::Value* input_clamped =
|
||||
emit_fmin(emit_fmax(input, llvm::ConstantFP::get(vector_type, -9.0)),
|
||||
llvm::ConstantFP::get(vector_type, 9.0));
|
||||
|
||||
std::array<float, 7> numerator_coeffs(
|
||||
{{-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
|
||||
@ -105,11 +133,13 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void RewriteIRRuntimeFunctions(llvm::Module* module) {
|
||||
auto* tanh_v4f32 = EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName,
|
||||
/*vector_width=*/4);
|
||||
auto* tanh_v8f32 = EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName,
|
||||
/*vector_width=*/8);
|
||||
void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
|
||||
auto* tanh_v4f32 =
|
||||
EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName,
|
||||
/*vector_width=*/4, enable_fast_math);
|
||||
auto* tanh_v8f32 =
|
||||
EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName,
|
||||
/*vector_width=*/8, enable_fast_math);
|
||||
|
||||
// Gather all the call sites, force inline them and then delete the vector
|
||||
// function bodies.
|
||||
|
@ -33,7 +33,7 @@ extern const char* const kTanhV8F32SymbolName;
|
||||
// |LinkIRRuntimeFunctions| rewrites calls to these functions into generic LLVM
|
||||
// IR.
|
||||
|
||||
void RewriteIRRuntimeFunctions(llvm::Module* module);
|
||||
void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math);
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
|
@ -171,7 +171,7 @@ CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
|
||||
|
||||
SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
||||
llvm::CodeGenOpt::Level opt_level,
|
||||
bool optimize_for_size,
|
||||
bool optimize_for_size, bool enable_fast_math,
|
||||
LLVMCompiler::ModuleHook pre_optimization_hook,
|
||||
LLVMCompiler::ModuleHook post_optimization_hook)
|
||||
: target_machine_(
|
||||
@ -186,12 +186,12 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
||||
data_layout_(target_machine_->createDataLayout()),
|
||||
object_layer_(
|
||||
[] { return std::make_shared<llvm::SectionMemoryManager>(); }),
|
||||
compile_layer_(
|
||||
object_layer_,
|
||||
CompilerFunctor(target_machine_.get(), &disassembler_, opt_level,
|
||||
optimize_for_size, GetAvailableIntrinsics(),
|
||||
std::move(pre_optimization_hook),
|
||||
std::move(post_optimization_hook))) {
|
||||
compile_layer_(object_layer_,
|
||||
CompilerFunctor(target_machine_.get(), &disassembler_,
|
||||
opt_level, optimize_for_size,
|
||||
enable_fast_math, GetAvailableIntrinsics(),
|
||||
std::move(pre_optimization_hook),
|
||||
std::move(post_optimization_hook))) {
|
||||
VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
|
||||
<< " features: " << target_machine_->getTargetFeatureString().str();
|
||||
}
|
||||
|
@ -63,6 +63,7 @@ class SimpleOrcJIT {
|
||||
// level optimizations are applied.
|
||||
SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
||||
llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
|
||||
bool enable_fast_math,
|
||||
LLVMCompiler::ModuleHook pre_optimization_hook,
|
||||
LLVMCompiler::ModuleHook post_optimization_hook);
|
||||
|
||||
|
@ -292,6 +292,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
|
||||
|
||||
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value) const {
|
||||
// TODO(b/64580527): We can do better here if fast-math is enabled.
|
||||
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
|
||||
{lhs_value, rhs_value},
|
||||
{lhs_value->getType()}, ir_builder_);
|
||||
@ -299,6 +300,7 @@ llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
|
||||
|
||||
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value) const {
|
||||
// TODO(b/64580527): We can do better here if fast-math is enabled.
|
||||
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
|
||||
{lhs_value, rhs_value},
|
||||
{lhs_value->getType()}, ir_builder_);
|
||||
|
Loading…
Reference in New Issue
Block a user