[XLA:CPU] Remove the global/module-level fast math flags
These are deprecated in favor of instruction-level fast math flags, and most of LLVM's backend code has been updated to use those instead. Dropping the global flags gives us more fine-grained control over fast math without any loss of performance.

Disabling UnsafeFPMath has the side effect of requiring __truncdfhf2 for double->half conversions, so provide that symbol. Also always allow FMA formation: while it is not IEEE 754 compliant, it never decreases accuracy.

PiperOrigin-RevId: 281801638
Change-Id: I2d96220fefebad4d11b1dab8f75b06ccb88a05bf
commit d04bfee679
parent 1703690e1e
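For context, a minimal sketch of the instruction-level mechanism the commit message refers to (not code from this change): LLVM's IRBuilder carries per-instruction fast math flags, so each emitted FP operation can opt in individually instead of relying on the global TargetOptions toggles. EmitFastAdd is a made-up helper name.

  // Sketch only, assuming LLVM's C++ API.
  #include "llvm/IR/IRBuilder.h"

  llvm::Value* EmitFastAdd(llvm::IRBuilder<>& b, llvm::Value* x, llvm::Value* y) {
    llvm::FastMathFlags fmf;
    fmf.setNoNaNs();         // per-instruction analogue of NoNaNsFPMath
    fmf.setNoInfs();         // per-instruction analogue of NoInfsFPMath
    fmf.setNoSignedZeros();  // per-instruction analogue of NoSignedZerosFPMath
    b.setFastMathFlags(fmf); // applies to FP instructions this builder creates
    return b.CreateFAdd(x, y, "fast_add");
  }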
@@ -409,20 +409,8 @@ auto memory_alignment = [](LogicalBuffer::Color) { return kMemoryAlignment; };
 llvm::TargetOptions CompilerTargetOptions(
     const HloModuleConfig& module_config) {
   llvm::TargetOptions target_options;
-  // In LLVM backend flags, UnsafeFPMath does not explicitly imply NoInfs, etc.
-  if (module_config.debug_options().xla_cpu_enable_fast_math()) {
-    target_options.UnsafeFPMath = true;
-    target_options.NoInfsFPMath =
-        !module_config.debug_options().xla_cpu_fast_math_honor_infs();
-    target_options.NoNaNsFPMath =
-        !module_config.debug_options().xla_cpu_fast_math_honor_nans();
-    target_options.NoSignedZerosFPMath = true;
-  } else {
-    target_options.UnsafeFPMath = false;
-    target_options.NoInfsFPMath = false;
-    target_options.NoNaNsFPMath = false;
-    target_options.NoSignedZerosFPMath = false;
-  }
+  // Always allow FMA fusion. This increases precision instead of decreasing it.
+  target_options.AllowFPOpFusion = llvm::FPOpFusion::Fast;
   return target_options;
 }
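A quick illustration (not part of the change) of why FMA formation never decreases accuracy: std::fma computes a*b + c with a single rounding, whereas the unfused expression rounds the product first. The constants below are chosen so the difference is visible in double precision; compile without FP contraction so the first expression stays unfused.

  #include <cmath>
  #include <cstdio>

  int main() {
    double a = 1.0 + 0x1p-27;          // exactly representable in double
    double b = 1.0 - 0x1p-27;
    double c = -1.0;
    double separate = a * b + c;       // a*b rounds to 1.0, so this is 0.0
    double fused = std::fma(a, b, c);  // single rounding keeps -0x1p-54
    std::printf("separate=%a fused=%a\n", separate, fused);
    return 0;
  }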
@@ -131,3 +131,9 @@ float TF_ATTRIBUTE_WEAK __gnu_h2f_ieee(uint16 h) {
   o.set_uint(o.as_uint() | (h & 0x8000) << 16);  // sign bit
   return o.as_float();
 }
+
+uint16 TF_ATTRIBUTE_WEAK __truncdfhf2(double d) {
+  // This does a double rounding step, but it's precise enough for our use
+  // cases.
+  return __gnu_f2h_ieee(static_cast<float>(d));
+}
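A hypothetical spot check of the new helper, not part of the commit: half-precision 1.0 has the bit pattern 0x3C00, so any correct double->half truncation must map 1.0 to it. The std::uint16_t declaration stands in for tensorflow::uint16.

  #include <cassert>
  #include <cstdint>

  extern "C" std::uint16_t __truncdfhf2(double);  // provided by runtime_fp16

  int main() {
    assert(__truncdfhf2(1.0) == 0x3C00);   // half 1.0
    assert(__truncdfhf2(-2.0) == 0xC000);  // half -2.0
    assert(__truncdfhf2(0.0) == 0x0000);   // half +0.0
    return 0;
  }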
@@ -24,4 +24,7 @@ extern "C" tensorflow::uint16 __gnu_f2h_ieee(float);
 // Converts an F16 value to a F32.
 extern "C" float __gnu_h2f_ieee(tensorflow::uint16);
 
+// Converts an F64 value to a F16.
+extern "C" tensorflow::uint16 __truncdfhf2(double);
+
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FP16_H_
@@ -250,6 +250,8 @@ bool RegisterKnownJITSymbols() {
                      "Host");
   registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee),
                      "Host");
+  registry->Register("__truncdfhf2", reinterpret_cast<void*>(__truncdfhf2),
+                     "Host");
 
 #undef REGISTER_CPU_RUNTIME_SYMBOL
 
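The registration above is what lets JIT-compiled code resolve __truncdfhf2 at runtime: on targets without native f16 support, LLVM lowers "fptrunc double ... to half" to a call to that libgcc-style symbol, which the JIT then looks up by name. A minimal sketch of the pattern, using a hypothetical map in place of XLA's real registry type:

  #include <cstdint>
  #include <map>
  #include <string>

  extern "C" std::uint16_t __truncdfhf2(double);

  // Hypothetical stand-in for XLA's symbol registry; the real type differs.
  std::map<std::string, void*>& SymbolRegistry() {
    static std::map<std::string, void*> registry;
    return registry;
  }

  void RegisterFp16Symbols() {
    // Unresolved symbols in generated code are resolved through this table.
    SymbolRegistry()["__truncdfhf2"] = reinterpret_cast<void*>(__truncdfhf2);
  }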
@@ -607,20 +607,6 @@ llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
   // created by the JIT compiled code.
   function->setHasUWTable();
 
-  if (module_config.debug_options().xla_cpu_enable_fast_math()) {
-    function->addFnAttr("unsafe-fp-math", "true");
-    function->addFnAttr("no-signed-zeros-fp-math", "true");
-    if (!module_config.debug_options().xla_cpu_fast_math_honor_nans()) {
-      function->addFnAttr("no-nans-fp-math", "true");
-    }
-    if (!module_config.debug_options().xla_cpu_fast_math_honor_infs()) {
-      function->addFnAttr("no-infs-fp-math", "true");
-    }
-    if (module_config.debug_options().xla_cpu_fast_math_honor_division()) {
-      function->addFnAttr("reciprocal-estimates", "none");
-    }
-  }
-
   // Add the optize attribute to the function if optimizing for size. This
   // controls internal behavior of some optimization passes (e.g. loop
   // unrolling).