Change XLA's default to disable cpu_fast_math options with the exception of min_max behavior.
This is due to issues around inf/nan behavior on the cpu. tf_library still enables all fast math options though with the observation that currently most users of this are desiring performance and have tested their code already. PiperOrigin-RevId: 311787817 Change-Id: Iab012d49435845dc5b7a5fcedca89bf159ec65a3
This commit is contained in:
parent
08968c30dc
commit
2bbf57217f
|
@ -42,7 +42,8 @@ def tf_library(
|
|||
mlir_components = "None",
|
||||
deps = None,
|
||||
tags = []):
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code.
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code with fast
|
||||
math enabled on cpu.
|
||||
|
||||
Given an invocation of tf_library(name="foo", ...), generates the following
|
||||
build targets:
|
||||
|
@ -207,6 +208,15 @@ def tf_library(
|
|||
srcs.append(debug_info)
|
||||
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
|
||||
|
||||
default_fast_math_xla_flags = "XLA_FLAGS=\"\
|
||||
--xla_cpu_enable_fast_math=true \
|
||||
--xla_cpu_fast_math_honor_nans=false \
|
||||
--xla_cpu_fast_math_honor_infs=false \
|
||||
--xla_cpu_fast_math_honor_functions=false \
|
||||
--xla_cpu_fast_math_honor_division=false \
|
||||
--xla_cpu_enable_fast_min_max=true \
|
||||
$${XLA_FLAGS:-}\" "
|
||||
|
||||
native.genrule(
|
||||
name = ("gen_" + name),
|
||||
srcs = srcs,
|
||||
|
@ -216,6 +226,7 @@ def tf_library(
|
|||
function_object_file,
|
||||
],
|
||||
cmd = (
|
||||
default_fast_math_xla_flags +
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
|
@ -256,6 +267,7 @@ def tf_library(
|
|||
session_module_pb,
|
||||
],
|
||||
cmd = (
|
||||
default_fast_math_xla_flags +
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
|
|
|
@ -67,6 +67,8 @@ int main(int argc, char** argv) {
|
|||
flags.entry_point = "entry";
|
||||
flags.debug_info_path_begin_marker = "";
|
||||
|
||||
// Note that tfcompile.bzl's tf_library macro sets fast math flags as that is
|
||||
// generally the preferred case.
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
AppendMainFlags(&flag_list, &flags);
|
||||
xla::AppendDebugOptionsFlags(&flag_list);
|
||||
|
|
|
@ -55,9 +55,16 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
|
|||
// b/77879207.
|
||||
opts.set_xla_gpu_disable_multi_streaming(true);
|
||||
|
||||
// TODO(jlebar): Disable fastmath once doing so is not a performance
|
||||
// regression.
|
||||
// Disable forms of fast math that have caused users problems in the past.
|
||||
opts.set_xla_cpu_enable_fast_math(true);
|
||||
opts.set_xla_cpu_fast_math_honor_nans(true);
|
||||
opts.set_xla_cpu_fast_math_honor_infs(true);
|
||||
opts.set_xla_cpu_fast_math_honor_functions(true);
|
||||
opts.set_xla_cpu_fast_math_honor_division(true);
|
||||
|
||||
// By default, copy TF's Eigen style min_max behavior with nans.
|
||||
opts.set_xla_cpu_enable_fast_min_max(false);
|
||||
|
||||
opts.set_xla_gpu_enable_fast_min_max(true);
|
||||
|
||||
opts.set_xla_allow_excess_precision(true);
|
||||
|
@ -261,6 +268,12 @@ static void AllocateFlags() {
|
|||
"When xla_cpu_enable_fast_math is true then this controls whether we "
|
||||
"forbid to approximate calculations for functions. Ignored when "
|
||||
"xla_cpu_enable_fast_math is false."));
|
||||
flag_objects->push_back(tensorflow::Flag(
|
||||
"xla_cpu_enable_fast_min_max",
|
||||
bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_min_max),
|
||||
flag_values->xla_cpu_enable_fast_min_max(),
|
||||
"Enable fast floating point min/max lowering that always propagates "
|
||||
"NaNs."));
|
||||
flag_objects->push_back(tensorflow::Flag(
|
||||
"xla_gpu_enable_fast_min_max",
|
||||
bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max),
|
||||
|
|
|
@ -872,11 +872,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||
DebugOptions* debug_options =
|
||||
options.executable_build_options.mutable_debug_options();
|
||||
// Sets fast-math-disabling default options expected by JAX.
|
||||
// TODO(phawkins): make these XLA-wide defaults.
|
||||
debug_options->set_xla_cpu_fast_math_honor_infs(true);
|
||||
debug_options->set_xla_cpu_fast_math_honor_nans(true);
|
||||
debug_options->set_xla_cpu_fast_math_honor_division(true);
|
||||
debug_options->set_xla_cpu_fast_math_honor_functions(true);
|
||||
debug_options->set_xla_cpu_enable_fast_min_max(false);
|
||||
debug_options->set_xla_gpu_enable_fast_min_max(false);
|
||||
return options;
|
||||
}))
|
||||
|
|
|
@ -70,6 +70,13 @@ class CpuUnaryIntrinsicTest
|
|||
return absl::StrCat(opcode, "_On_", triple,
|
||||
(features.empty() ? "" : "_With"), features);
|
||||
}
|
||||
|
||||
private:
|
||||
DebugOptions GetDebugOptionsForTest() override {
|
||||
DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
|
||||
HloTestBase::SetAotFastMathDebugOptions(&debug_options);
|
||||
return debug_options;
|
||||
}
|
||||
};
|
||||
|
||||
// Creates a module with a call to the unary op, and tests if the
|
||||
|
|
|
@ -69,6 +69,13 @@ class CpuVectorizationTest
|
|||
return absl::StrCat(opcode, "_On_", triple,
|
||||
(features.empty() ? "" : "_With"), features);
|
||||
}
|
||||
|
||||
private:
|
||||
DebugOptions GetDebugOptionsForTest() override {
|
||||
DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
|
||||
HloTestBase::SetAotFastMathDebugOptions(&debug_options);
|
||||
return debug_options;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(CpuVectorizationTest, DoIt) {
|
||||
|
|
|
@ -64,6 +64,7 @@ cc_library(
|
|||
srcs = ["llvm_util.cc"],
|
||||
hdrs = ["llvm_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
|
|
|
@ -58,7 +58,7 @@ ENTRY while3 {
|
|||
|
||||
CompileAndVerifyIr(hlo_string, R"(
|
||||
; CHECK-LABEL: @body(i8* %retval
|
||||
; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]]
|
||||
; CHECK: %[[add_result:.*]] = fadd reassoc nsz contract float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]]
|
||||
; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], align 4, !alias.scope ![[alias_scope_md_for_store:[0-9]+]]
|
||||
;
|
||||
; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params
|
||||
|
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
||||
|
@ -90,7 +91,9 @@ llvm::CallInst* EmitCallToIntrinsic(
|
|||
|
||||
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
|
||||
llvm::IRBuilder<>* b) {
|
||||
if (b->getFastMathFlags().noNaNs()) {
|
||||
// TODO(tpopp): Pass this information down from the HLO's ModuleConfig.
|
||||
if (b->getFastMathFlags().noNaNs() ||
|
||||
GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) {
|
||||
auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
|
||||
return b->CreateSelect(cmp, lhs_value, rhs_value);
|
||||
} else {
|
||||
|
@ -103,7 +106,9 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
|
|||
|
||||
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
|
||||
llvm::IRBuilder<>* b) {
|
||||
if (b->getFastMathFlags().noNaNs()) {
|
||||
// TODO(tpopp): Pass this information down from the HLO's ModuleConfig.
|
||||
if (b->getFastMathFlags().noNaNs() ||
|
||||
GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) {
|
||||
auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
|
||||
return b->CreateSelect(cmp, lhs_value, rhs_value);
|
||||
} else {
|
||||
|
|
|
@ -76,6 +76,7 @@ class ClientLibraryTestBase : public ::testing::Test {
|
|||
void SetFastMathDisabled(bool disabled) {
|
||||
auto* opts = execution_options_.mutable_debug_options();
|
||||
opts->set_xla_cpu_enable_fast_math(!disabled);
|
||||
opts->set_xla_cpu_enable_fast_min_max(!disabled);
|
||||
opts->set_xla_gpu_enable_fast_min_max(!disabled);
|
||||
}
|
||||
|
||||
|
|
|
@ -165,6 +165,16 @@ PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
|
|||
return precision_config;
|
||||
}
|
||||
|
||||
void HloTestBase::SetAotFastMathDebugOptions(DebugOptions* options) {
|
||||
options->set_xla_cpu_enable_fast_math(true);
|
||||
options->set_xla_gpu_enable_fast_min_max(true);
|
||||
options->set_xla_cpu_enable_fast_min_max(true);
|
||||
options->set_xla_cpu_fast_math_honor_nans(false);
|
||||
options->set_xla_cpu_fast_math_honor_infs(false);
|
||||
options->set_xla_cpu_fast_math_honor_functions(false);
|
||||
options->set_xla_cpu_fast_math_honor_division(false);
|
||||
}
|
||||
|
||||
DebugOptions HloTestBase::GetDebugOptionsForTest() {
|
||||
auto debug_options = GetDebugOptionsFromFlags();
|
||||
// TODO(b/38354253): Change tests to use Parameters instead of Constants.
|
||||
|
|
|
@ -100,6 +100,10 @@ class HloTestBase : public ::testing::Test {
|
|||
|
||||
static PrecisionConfig DefaultPrecisionConfig(int operands);
|
||||
|
||||
// Sets most fath math options to be enabled to model the fast math flags
|
||||
// generally used for CPU:AOT compilation.
|
||||
static void SetAotFastMathDebugOptions(DebugOptions* options);
|
||||
|
||||
protected:
|
||||
// This uses the interpreter backend as the reference backend and
|
||||
// automatically finds another supported backend as the test backend. If the
|
||||
|
|
|
@ -310,8 +310,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
|
|||
|
||||
XLA_TEST_F(VecOpsSimpleTest, ClampFloatEdgeCases) {
|
||||
XlaBuilder builder(TestName());
|
||||
mutable_debug_options()->set_xla_cpu_enable_fast_math(false);
|
||||
mutable_debug_options()->set_xla_gpu_enable_fast_min_max(false);
|
||||
SetFastMathDisabled(true);
|
||||
auto low = ConstantR1<float>(&builder, {NAN, 1, 1});
|
||||
auto high = ConstantR1<float>(&builder, {3, NAN, 3});
|
||||
auto x = ConstantR1<float>(&builder, {2, 2, NAN});
|
||||
|
|
|
@ -148,9 +148,20 @@ message DebugOptions {
|
|||
// xla_cpu_enable_fast_math is false.
|
||||
bool xla_cpu_fast_math_honor_functions = 129;
|
||||
|
||||
// When false we lower the Minimum and Maximum hlos in the CPU backend such
|
||||
// that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if flag
|
||||
// this is false we always propagate NaNs through Min and Max.
|
||||
//
|
||||
// Note, this does not correspond to the exact same behavior as the gpu flag
|
||||
// below!
|
||||
bool xla_cpu_enable_fast_min_max = 140;
|
||||
|
||||
// When true we lower the Minimum and Maximum hlos in the GPU backend such
|
||||
// that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag
|
||||
// this is true we don't propagate NaNs through Min and Max.
|
||||
//
|
||||
// Note, this does not correspond to the exact same behavior as the cpu flag
|
||||
// above!
|
||||
bool xla_gpu_enable_fast_min_max = 100;
|
||||
|
||||
// Allows xla to increase the output precision of floating point operations.
|
||||
|
@ -280,7 +291,7 @@ message DebugOptions {
|
|||
// memory, or have bugs.
|
||||
bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_error = 139;
|
||||
|
||||
// Next id: 140
|
||||
// Next id: 141
|
||||
|
||||
// Extra options to pass to the compilation backend (e.g. LLVM); specific
|
||||
// interpretation of these values is left to the backend.
|
||||
|
|
|
@ -55,8 +55,8 @@ class BetaincTest(test.TestCase):
|
|||
# the scipy version of betainc uses a double-only implementation.
|
||||
# TODO(ebrevdo): identify reasons for (sometime) precision loss
|
||||
# with doubles
|
||||
rtol = 1e-4 if dtype == dtypes.float32 else 5e-5
|
||||
atol = 9e-6 if dtype == dtypes.float32 else 3e-6
|
||||
rtol = 1e-4
|
||||
atol = 1e-5
|
||||
self.assertAllCloseAccordingToType(
|
||||
scipy_out, tf_out, rtol=rtol, atol=atol)
|
||||
|
||||
|
@ -66,7 +66,8 @@ class BetaincTest(test.TestCase):
|
|||
with self.cached_session():
|
||||
tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval()
|
||||
scipy_comb = special.betainc(a_comb, b_comb, x_comb, dtype=np_dt)
|
||||
self.assertAllCloseAccordingToType(scipy_comb, tf_comb)
|
||||
self.assertAllCloseAccordingToType(
|
||||
scipy_comb, tf_comb, rtol=rtol, atol=atol)
|
||||
|
||||
# Test broadcasting between scalars and other shapes
|
||||
with self.cached_session():
|
||||
|
|
|
@ -149,7 +149,7 @@ class GradientCheckerTest(test.TestCase):
|
|||
self.assertAllEqual(correct, analytical)
|
||||
self.assertAllClose(correct, numerical, rtol=1e-4)
|
||||
self.assertLess(
|
||||
gradient_checker.compute_gradient_error(x, size, y, size), 2e-4)
|
||||
gradient_checker.compute_gradient_error(x, size, y, size), 3e-4)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testComplexConj(self):
|
||||
|
|
Loading…
Reference in New Issue