diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index abccefbcdbb..f0c3e7da0ba 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -42,7 +42,8 @@ def tf_library(
         mlir_components = "None",
        deps = None,
        tags = []):
-    """Runs tfcompile to compile a TensorFlow graph into executable code.
+    """Runs tfcompile to compile a TensorFlow graph into executable code with fast
+    math enabled on CPU.

    Given an invocation of tf_library(name="foo", ...), generates the following
    build targets:
@@ -207,6 +208,15 @@ def tf_library(
        srcs.append(debug_info)
        debug_info_flag = " --debug_info=$(location " + debug_info + ")"

+    default_fast_math_xla_flags = "XLA_FLAGS=\"\
+        --xla_cpu_enable_fast_math=true \
+        --xla_cpu_fast_math_honor_nans=false \
+        --xla_cpu_fast_math_honor_infs=false \
+        --xla_cpu_fast_math_honor_functions=false \
+        --xla_cpu_fast_math_honor_division=false \
+        --xla_cpu_enable_fast_min_max=true \
+        $${XLA_FLAGS:-}\" "
+
    native.genrule(
        name = ("gen_" + name),
        srcs = srcs,
@@ -216,6 +226,7 @@ def tf_library(
            function_object_file,
        ],
        cmd = (
+            default_fast_math_xla_flags +
            "CUDA_VISIBLE_DEVICES='' " +
            "$(location " + tfcompile_tool + ")" +
            " --graph=$(location " + tfcompile_graph + ")" +
@@ -256,6 +267,7 @@ def tf_library(
            session_module_pb,
        ],
        cmd = (
+            default_fast_math_xla_flags +
            "CUDA_VISIBLE_DEVICES='' " +
            "$(location " + tfcompile_tool + ")" +
            " --graph=$(location " + tfcompile_graph + ")" +
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index f0cf8f2ded9..846947454bb 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -67,6 +67,8 @@ int main(int argc, char** argv) {
   flags.entry_point = "entry";
   flags.debug_info_path_begin_marker = "";

+  // Note that tfcompile.bzl's tf_library macro sets fast math flags, as that
+  // is generally the preferred behavior.
   std::vector<tensorflow::Flag> flag_list;
   AppendMainFlags(&flag_list, &flags);
   xla::AppendDebugOptionsFlags(&flag_list);
diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc
index 60a563ee956..4152982bf4c 100644
--- a/tensorflow/compiler/xla/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/debug_options_flags.cc
@@ -55,9 +55,16 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   // b/77879207.
   opts.set_xla_gpu_disable_multi_streaming(true);

-  // TODO(jlebar): Disable fastmath once doing so is not a performance
-  // regression.
+  // Disable forms of fast math that have caused problems for users in the past.
   opts.set_xla_cpu_enable_fast_math(true);
+  opts.set_xla_cpu_fast_math_honor_nans(true);
+  opts.set_xla_cpu_fast_math_honor_infs(true);
+  opts.set_xla_cpu_fast_math_honor_functions(true);
+  opts.set_xla_cpu_fast_math_honor_division(true);
+
+  // By default, copy TF's Eigen-style min/max behavior with NaNs.
+  opts.set_xla_cpu_enable_fast_min_max(false);
+  opts.set_xla_gpu_enable_fast_min_max(true);

   opts.set_xla_allow_excess_precision(true);
@@ -261,6 +268,12 @@ static void AllocateFlags() {
       "When xla_cpu_enable_fast_math is true then this controls whether we "
       "forbid approximate calculations for functions. Ignored when "
       "xla_cpu_enable_fast_math is false."));
+  flag_objects->push_back(tensorflow::Flag(
+      "xla_cpu_enable_fast_min_max",
+      bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_min_max),
+      flag_values->xla_cpu_enable_fast_min_max(),
+      "Enable fast floating point min/max lowering that might not propagate "
+      "NaNs."));
   flag_objects->push_back(tensorflow::Flag(
       "xla_gpu_enable_fast_min_max",
       bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max),
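The two layers above interact: DefaultDebugOptionsIgnoringFlags() now bakes in
the conservative settings (fast math on, but every honor_* knob true and fast
min/max off), while the XLA_FLAGS prefix that tf_library adds to its genrule
commands switches the aggressive flags back on for AOT builds. Because
$${XLA_FLAGS:-} expands at the end of that prefix, flags a user sets in the
build environment are parsed after the defaults and should win. A minimal C++
sketch of the two entry points involved (illustrative only, not part of the
patch; both helpers are declared in debug_options_flags.h):

    #include "tensorflow/compiler/xla/debug_options_flags.h"

    void ShowWhereTheOptionsComeFrom() {
      // Hard-coded defaults: fast math is on, but the honor_* knobs keep the
      // numerically risky rewrites disabled and fast min/max stays off.
      xla::DebugOptions defaults = xla::DefaultDebugOptionsIgnoringFlags();

      // Defaults merged with whatever XLA_FLAGS holds in the environment;
      // under the tf_library genrule prefix above, this is where AOT builds
      // get the fast math flags back.
      xla::DebugOptions with_env = xla::GetDebugOptionsFromFlags();
      (void)defaults;
      (void)with_env;
    }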
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index 65fb5311994..f10ec978399 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -872,11 +872,7 @@ PYBIND11_MODULE(xla_extension, m) {
             DebugOptions* debug_options =
                 options.executable_build_options.mutable_debug_options();
             // Sets fast-math-disabling default options expected by JAX.
-            // TODO(phawkins): make these XLA-wide defaults.
-            debug_options->set_xla_cpu_fast_math_honor_infs(true);
-            debug_options->set_xla_cpu_fast_math_honor_nans(true);
-            debug_options->set_xla_cpu_fast_math_honor_division(true);
-            debug_options->set_xla_cpu_fast_math_honor_functions(true);
+            debug_options->set_xla_cpu_enable_fast_min_max(false);
             debug_options->set_xla_gpu_enable_fast_min_max(false);
             return options;
           }))
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
index b6d6de28bc5..efeab3bd31a 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc
@@ -70,6 +70,13 @@ class CpuUnaryIntrinsicTest
     return absl::StrCat(opcode, "_On_", triple,
                         (features.empty() ? "" : "_With"), features);
   }
+
+ private:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+    HloTestBase::SetAotFastMathDebugOptions(&debug_options);
+    return debug_options;
+  }
 };

 // Creates a module with a call to the unary op, and tests if the
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc
index 8a72eb15487..757d878e224 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc
@@ -69,6 +69,13 @@ class CpuVectorizationTest
     return absl::StrCat(opcode, "_On_", triple,
                         (features.empty() ? "" : "_With"), features);
   }
+
+ private:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
+    HloTestBase::SetAotFastMathDebugOptions(&debug_options);
+    return debug_options;
+  }
 };

 TEST_P(CpuVectorizationTest, DoIt) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 39399df7ad8..cabcc8e06ee 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -64,6 +64,7 @@ cc_library(
     srcs = ["llvm_util.cc"],
     hdrs = ["llvm_util.h"],
     deps = [
+        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
index 453a5cd84b2..f7808773592 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
@@ -58,7 +58,7 @@ ENTRY while3 {
   CompileAndVerifyIr(hlo_string, R"(
 ; CHECK-LABEL: @body(i8* %retval
-; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]]
+; CHECK: %[[add_result:.*]] = fadd reassoc nsz contract float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]]
 ; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], align 4, !alias.scope ![[alias_scope_md_for_store:[0-9]+]]
 ;
 ; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index 4c9a8d3e004..c2b11819448 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Target/TargetOptions.h"
 #include "llvm/Transforms/Utils/Cloning.h"
+#include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
@@ -90,7 +91,9 @@ llvm::CallInst* EmitCallToIntrinsic(

 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                           llvm::IRBuilder<>* b) {
-  if (b->getFastMathFlags().noNaNs()) {
+  // TODO(tpopp): Pass this information down from the HLO's ModuleConfig.
+  if (b->getFastMathFlags().noNaNs() ||
+      GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) {
     auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
     return b->CreateSelect(cmp, lhs_value, rhs_value);
   } else {
@@ -103,7 +106,9 @@ llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
                           llvm::IRBuilder<>* b) {
-  if (b->getFastMathFlags().noNaNs()) {
+  // TODO(tpopp): Pass this information down from the HLO's ModuleConfig.
+  if (b->getFastMathFlags().noNaNs() ||
+      GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) {
     auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
     return b->CreateSelect(cmp, lhs_value, rhs_value);
   } else {
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 5b83186ffa4..790497f888e 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -76,6 +76,7 @@ class ClientLibraryTestBase : public ::testing::Test {
   void SetFastMathDisabled(bool disabled) {
     auto* opts = execution_options_.mutable_debug_options();
     opts->set_xla_cpu_enable_fast_math(!disabled);
+    opts->set_xla_cpu_enable_fast_min_max(!disabled);
     opts->set_xla_gpu_enable_fast_min_max(!disabled);
   }
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 8eed609a134..7b64be5597b 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -165,6 +165,16 @@ PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
   return precision_config;
 }

+void HloTestBase::SetAotFastMathDebugOptions(DebugOptions* options) {
+  options->set_xla_cpu_enable_fast_math(true);
+  options->set_xla_gpu_enable_fast_min_max(true);
+  options->set_xla_cpu_enable_fast_min_max(true);
+  options->set_xla_cpu_fast_math_honor_nans(false);
+  options->set_xla_cpu_fast_math_honor_infs(false);
+  options->set_xla_cpu_fast_math_honor_functions(false);
+  options->set_xla_cpu_fast_math_honor_division(false);
+}
+
 DebugOptions HloTestBase::GetDebugOptionsForTest() {
   auto debug_options = GetDebugOptionsFromFlags();
   // TODO(b/38354253): Change tests to use Parameters instead of Constants.
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index d05776a0cb9..85b1876dd3c 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -100,6 +100,10 @@ class HloTestBase : public ::testing::Test {

   static PrecisionConfig DefaultPrecisionConfig(int operands);

+  // Sets most fast math options to be enabled, to model the fast math flags
+  // generally used for CPU:AOT compilation.
+  static void SetAotFastMathDebugOptions(DebugOptions* options);
+
  protected:
   // This uses the interpreter backend as the reference backend and
   // automatically finds another supported backend as the test backend. If the
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
index 3407a68f709..40e226f9902 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -310,8 +310,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {

 XLA_TEST_F(VecOpsSimpleTest, ClampFloatEdgeCases) {
   XlaBuilder builder(TestName());
-  mutable_debug_options()->set_xla_cpu_enable_fast_math(false);
-  mutable_debug_options()->set_xla_gpu_enable_fast_min_max(false);
+  SetFastMathDisabled(true);
   auto low = ConstantR1<float>(&builder, {NAN, 1, 1});
   auto high = ConstantR1<float>(&builder, {3, NAN, 3});
   auto x = ConstantR1<float>(&builder, {2, 2, NAN});
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index f4b08f454b9..9374b1fca6a 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -148,9 +148,20 @@ message DebugOptions {
   // xla_cpu_enable_fast_math is false.
   bool xla_cpu_fast_math_honor_functions = 129;

+  // When false we lower the Minimum and Maximum hlos in the CPU backend such
+  // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if this
+  // flag is false we always propagate NaNs through Min and Max.
+  //
+  // Note that this does not correspond to the exact same behavior as the GPU
+  // flag below!
+  bool xla_cpu_enable_fast_min_max = 140;
+
   // When true we lower the Minimum and Maximum hlos in the GPU backend such
   // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if
   // this flag is true we don't propagate NaNs through Min and Max.
+  //
+  // Note that this does not correspond to the exact same behavior as the CPU
+  // flag above!
   bool xla_gpu_enable_fast_min_max = 100;

   // Allows xla to increase the output precision of floating point operations.
@@ -280,7 +291,7 @@ message DebugOptions {
   // memory, or have bugs.
   bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_error = 139;

-  // Next id: 140
+  // Next id: 141

   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
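Together with the llvm_util.cc sketch earlier, these proto comments pin down
three conventions: CPU flag false always propagates NaNs (StrictMin above),
CPU flag true is the order-dependent select lowering (FastMin above), and GPU
flag true never propagates NaNs. The GPU convention happens to match
std::fmin, so a one-line reference (an assumption drawn only from the comment
above, not code from the patch) is:

    #include <cmath>

    // GPU flag true: NaNs never propagate. std::fmin returns the non-NaN
    // operand when exactly one side is NaN, so GpuFastMin(NAN, 1.0f) and
    // GpuFastMin(1.0f, NAN) both return 1.0f.
    float GpuFastMin(float a, float b) { return std::fmin(a, b); }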
diff --git a/tensorflow/python/kernel_tests/betainc_op_test.py b/tensorflow/python/kernel_tests/betainc_op_test.py
index c4f70b5bc29..c564c822918 100644
--- a/tensorflow/python/kernel_tests/betainc_op_test.py
+++ b/tensorflow/python/kernel_tests/betainc_op_test.py
@@ -55,8 +55,8 @@ class BetaincTest(test.TestCase):
       # the scipy version of betainc uses a double-only implementation.
       # TODO(ebrevdo): identify reasons for (sometime) precision loss
       # with doubles
-      rtol = 1e-4 if dtype == dtypes.float32 else 5e-5
-      atol = 9e-6 if dtype == dtypes.float32 else 3e-6
+      rtol = 1e-4
+      atol = 1e-5
       self.assertAllCloseAccordingToType(
           scipy_out, tf_out, rtol=rtol, atol=atol)

@@ -66,7 +66,8 @@ class BetaincTest(test.TestCase):
     with self.cached_session():
       tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval()
       scipy_comb = special.betainc(a_comb, b_comb, x_comb, dtype=np_dt)
-      self.assertAllCloseAccordingToType(scipy_comb, tf_comb)
+      self.assertAllCloseAccordingToType(
+          scipy_comb, tf_comb, rtol=rtol, atol=atol)

     # Test broadcasting between scalars and other shapes
     with self.cached_session():
diff --git a/tensorflow/python/ops/gradient_checker_test.py b/tensorflow/python/ops/gradient_checker_test.py
index 92ca9c2971e..c8ebf12569a 100644
--- a/tensorflow/python/ops/gradient_checker_test.py
+++ b/tensorflow/python/ops/gradient_checker_test.py
@@ -149,7 +149,7 @@ class GradientCheckerTest(test.TestCase):
     self.assertAllEqual(correct, analytical)
     self.assertAllClose(correct, numerical, rtol=1e-4)
     self.assertLess(
-        gradient_checker.compute_gradient_error(x, size, y, size), 2e-4)
+        gradient_checker.compute_gradient_error(x, size, y, size), 3e-4)

   @test_util.run_deprecated_v1
   def testComplexConj(self):