In fast-math mode emit a tanh that has a faster min/max.

PiperOrigin-RevId: 164943597
Author: A. Unique TensorFlower
Date:   2017-08-10 22:09:03 -07:00
Committer: TensorFlower Gardener
parent 87605f3d6a
commit c0f9b0a91e
8 changed files with 63 additions and 26 deletions
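
The gist of the change: XLA's vectorized tanh clamps its input with min/max, and this patch threads an enable_fast_math bit from the compiler options down to the IR emitter so the clamp can be emitted as an unordered compare plus select (which LLVM folds into a single vminps/vmaxps on x86) instead of the stricter llvm.minnum/llvm.maxnum intrinsics. A minimal standalone sketch of that choice, condensed from the EmitFMinOrMax helper added below (illustrative, not the verbatim XLA source):

    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Intrinsics.h"
    #include "llvm/IR/Module.h"

    // Emit a floating-point min. Under fast-math (nnan) either operand may
    // be returned when a NaN appears, so an *unordered* compare plus select
    // is legal and folds to vminps on x86. Otherwise we must preserve IEEE
    // minnum semantics and call the intrinsic.
    llvm::Value* EmitMinSketch(llvm::IRBuilder<>* b, llvm::Module* m,
                               llvm::Type* type, llvm::Value* lhs,
                               llvm::Value* rhs, bool enable_fast_math) {
      if (enable_fast_math) {
        // FCMP_ULE: true if lhs <= rhs, or if either operand is a NaN.
        llvm::Value* cmp = b->CreateFCmp(llvm::FCmpInst::FCMP_ULE, lhs, rhs);
        return b->CreateSelect(cmp, lhs, rhs);
      }
      llvm::Function* minnum =
          llvm::Intrinsic::getDeclaration(m, llvm::Intrinsic::minnum, {type});
      return b->CreateCall(minnum, {lhs, rhs});
    }

Max is symmetric: FCMP_UGE in the fast path, llvm.maxnum otherwise.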

@@ -100,7 +100,7 @@ operator()(llvm::Module& module) const {
   CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
-  runtime::RewriteIRRuntimeFunctions(&module);
+  runtime::RewriteIRRuntimeFunctions(&module, enable_fast_math_);
   // Buffer for holding machine code prior to constructing the ObjectFile.
   llvm::SmallVector<char, 0> stream_buffer;

@@ -42,7 +42,7 @@ class CompilerFunctor {
   explicit CompilerFunctor(
       llvm::TargetMachine* target_machine, const Disassembler* disassembler,
-      int opt_level, bool optimize_for_size,
+      int opt_level, bool optimize_for_size, bool enable_fast_math,
       const VectorIntrinsics& available_intrinsics,
       LLVMCompiler::ModuleHook pre_optimization_hook = nullptr,
       LLVMCompiler::ModuleHook post_optimization_hook = nullptr)
@@ -50,6 +50,7 @@ class CompilerFunctor {
         disassembler_(CHECK_NOTNULL(disassembler)),
         opt_level_(opt_level),
         optimize_for_size_(optimize_for_size),
+        enable_fast_math_(enable_fast_math),
         available_intrinsics_(available_intrinsics),
         pre_optimization_hook_(pre_optimization_hook),
         post_optimization_hook_(post_optimization_hook) {}
@@ -72,6 +73,7 @@ class CompilerFunctor {
   const Disassembler* disassembler_;
   const unsigned opt_level_;
   const bool optimize_for_size_;
+  const bool enable_fast_math_;
   const VectorIntrinsics available_intrinsics_;
   LLVMCompiler::ModuleHook pre_optimization_hook_;
   LLVMCompiler::ModuleHook post_optimization_hook_;

@@ -442,6 +442,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
       CompilerTargetOptions(module->config()),
       CodeGenOptLevel(module->config()),
       options::OptimizeForSizeRequested(module->config()),
+      module->config().debug_options().xla_enable_fast_math(),
       pre_optimization_ir_hook, post_optimization_ir_hook);
   llvm_module->setDataLayout(jit->data_layout());
   llvm_module->setTargetTriple(jit->target_triple().getTriple());
@@ -794,6 +795,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
     CompilerFunctor compiler_functor(
         target_machine.get(), &disassembler, opt_level,
         options::OptimizeForSizeRequested(module->config()),
+        module->config().debug_options().xla_enable_fast_math(),
         CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook,
         post_optimization_ir_dump_hook);
     llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =

@@ -30,9 +30,33 @@ const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32";
 const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
 namespace {
+llvm::Value* EmitFMinOrMax(llvm::IRBuilder<>* ir_builder, llvm::Module* module,
+                           llvm::Type* vector_type, llvm::Value* lhs,
+                           llvm::Value* rhs, bool is_min,
+                           bool enable_fast_math) {
+  if (enable_fast_math) {
+    // Using an unordered comparison lets LLVM generate a vminps / vmaxps
+    // instruction on x86. vminps/vmaxps choose the second operand if either
+    // operand is a NaN and thus don't accurately implement the semantics of
+    // the minnum and maxnum intrinsics, necessitating different IR emission.
+    //
+    // We can _probably_ do this even when fast math is disabled, but we can
+    // certainly do this if fast math is enabled (and nnan applies).
+    auto* compare = ir_builder->CreateFCmp(
+        is_min ? llvm::FCmpInst::FCMP_ULE : llvm::FCmpInst::FCMP_UGE, lhs,
+        rhs);
+    return ir_builder->CreateSelect(compare, lhs, rhs);
+  } else {
+    llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
+        module, is_min ? llvm::Intrinsic::minnum : llvm::Intrinsic::maxnum,
+        vector_type);
+    return ir_builder->CreateCall(intrinsic, {lhs, rhs});
+  }
+}
 llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
                                           llvm::StringRef function_name,
-                                          int vector_width) {
+                                          int vector_width,
+                                          bool enable_fast_math) {
   llvm::Function* vector_tanh_function = module->getFunction(function_name);
   if (vector_tanh_function == nullptr) {
     // If the function declaration is not present in the module, there can't be
@@ -45,11 +69,6 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
   llvm::VectorType* vector_type =
       llvm::VectorType::get(float_type, vector_width);
-  llvm::Function* min_intrinsic = llvm::Intrinsic::getDeclaration(
-      module, llvm::Intrinsic::minnum, vector_type);
-  llvm::Function* max_intrinsic = llvm::Intrinsic::getDeclaration(
-      module, llvm::Intrinsic::maxnum, vector_type);
   llvm::BasicBlock* vector_tanh_body =
       llvm::BasicBlock::Create(*context, "body", vector_tanh_function);
@@ -59,15 +78,24 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
   fast_math_flags.setUnsafeAlgebra();
   ir_builder.setFastMathFlags(fast_math_flags);
+  auto emit_fmin = [&](llvm::Value* lhs, llvm::Value* rhs) {
+    return EmitFMinOrMax(&ir_builder, module, vector_type, lhs, rhs,
+                         /*is_min=*/true,
+                         /*enable_fast_math=*/enable_fast_math);
+  };
+  auto emit_fmax = [&](llvm::Value* lhs, llvm::Value* rhs) {
+    return EmitFMinOrMax(&ir_builder, module, vector_type, lhs, rhs,
+                         /*is_min=*/false,
+                         /*enable_fast_math=*/enable_fast_math);
+  };
   llvm::Value* input = &*vector_tanh_function->arg_begin();
   CHECK_EQ(input->getType(), vector_type);
   // This implements the same rational interpolant as implemented in Eigen3.
-  llvm::Value* input_clamped = ir_builder.CreateCall(
-      min_intrinsic,
-      {ir_builder.CreateCall(max_intrinsic,
-                             {input, llvm::ConstantFP::get(vector_type, -9.0)}),
-       llvm::ConstantFP::get(vector_type, 9.0)});
+  llvm::Value* input_clamped =
+      emit_fmin(emit_fmax(input, llvm::ConstantFP::get(vector_type, -9.0)),
+                llvm::ConstantFP::get(vector_type, 9.0));
   std::array<float, 7> numerator_coeffs(
       {{-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
@@ -105,11 +133,13 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
 }
 }  // namespace
-void RewriteIRRuntimeFunctions(llvm::Module* module) {
-  auto* tanh_v4f32 = EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName,
-                                               /*vector_width=*/4);
-  auto* tanh_v8f32 = EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName,
-                                               /*vector_width=*/8);
+void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
+  auto* tanh_v4f32 =
+      EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName,
+                                /*vector_width=*/4, enable_fast_math);
+  auto* tanh_v8f32 =
+      EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName,
+                                /*vector_width=*/8, enable_fast_math);
   // Gather all the call sites, force inline them and then delete the vector
   // function bodies.
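
To make the rewritten clamp concrete: for ordinary (non-NaN) inputs, emit_fmin(emit_fmax(input, -9.0), 9.0) is a plain clamp to [-9, 9], the range the Eigen-style rational interpolant covers; beyond it tanh has already saturated to +/-1 within float precision. A scalar C++ illustration (a hypothetical helper for exposition, not part of the patch; under fast-math the NaN behavior of the emitted IR is allowed to differ):

    // Clamp x to [-9.0f, 9.0f], matching emit_fmin(emit_fmax(x, -9), 9)
    // for non-NaN x; the rational interpolant is then evaluated on the result.
    float ClampForTanh(float x) {
      float y = x > -9.0f ? x : -9.0f;  // fmax(x, -9.0f)
      return y < 9.0f ? y : 9.0f;       // fmin(y, 9.0f)
    }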

@@ -33,7 +33,7 @@ extern const char* const kTanhV8F32SymbolName;
 // |LinkIRRuntimeFunctions| rewrites calls to these functions into generic LLVM
 // IR.
-void RewriteIRRuntimeFunctions(llvm::Module* module);
+void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math);
 }  // namespace runtime
 }  // namespace cpu

@@ -171,7 +171,7 @@ CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
 SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
                            llvm::CodeGenOpt::Level opt_level,
-                           bool optimize_for_size,
+                           bool optimize_for_size, bool enable_fast_math,
                            LLVMCompiler::ModuleHook pre_optimization_hook,
                            LLVMCompiler::ModuleHook post_optimization_hook)
     : target_machine_(
@@ -186,12 +186,12 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
       data_layout_(target_machine_->createDataLayout()),
       object_layer_(
           [] { return std::make_shared<llvm::SectionMemoryManager>(); }),
-      compile_layer_(
-          object_layer_,
-          CompilerFunctor(target_machine_.get(), &disassembler_, opt_level,
-                          optimize_for_size, GetAvailableIntrinsics(),
-                          std::move(pre_optimization_hook),
-                          std::move(post_optimization_hook))) {
+      compile_layer_(object_layer_,
+                     CompilerFunctor(target_machine_.get(), &disassembler_,
+                                     opt_level, optimize_for_size,
+                                     enable_fast_math, GetAvailableIntrinsics(),
+                                     std::move(pre_optimization_hook),
+                                     std::move(post_optimization_hook))) {
   VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
           << " features: " << target_machine_->getTargetFeatureString().str();
 }

@@ -63,6 +63,7 @@ class SimpleOrcJIT {
   // level optimizations are applied.
   SimpleOrcJIT(const llvm::TargetOptions& target_options,
                llvm::CodeGenOpt::Level opt_level, bool optimize_for_size,
+               bool enable_fast_math,
                LLVMCompiler::ModuleHook pre_optimization_hook,
                LLVMCompiler::ModuleHook post_optimization_hook);

@@ -292,6 +292,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
 llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
                                               llvm::Value* rhs_value) const {
+  // TODO(b/64580527): We can do better here if fast-math is enabled.
   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
                                       {lhs_value, rhs_value},
                                       {lhs_value->getType()}, ir_builder_);
@@ -299,6 +300,7 @@ llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
 llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
                                               llvm::Value* rhs_value) const {
+  // TODO(b/64580527): We can do better here if fast-math is enabled.
   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
                                       {lhs_value, rhs_value},
                                       {lhs_value->getType()}, ir_builder_);
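
The two TODO(b/64580527) comments mark the scalar EmitFloatMax/EmitFloatMin paths as candidates for the same treatment: presumably the unordered-compare-plus-select lowering used by EmitFMinOrMax above could replace the maxnum/minnum intrinsics here as well once fast-math is known to be enabled.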