diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc
index fe2b2ca6dc4..4b5aa384e22 100644
--- a/tensorflow/compiler/xla/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/debug_options_flags.cc
@@ -34,6 +34,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_llvm_enable_invariant_load_metadata(true);
   opts.set_xla_llvm_disable_expensive_passes(false);
   opts.set_xla_backend_optimization_level(3);
+  opts.set_xla_gpu_autotune_level(4);
   opts.set_xla_cpu_multi_thread_eigen(true);
   opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
   opts.set_xla_eliminate_hlo_implicit_broadcast(true);
@@ -399,10 +400,12 @@ static void AllocateFlags() {
           "Crashes the program on extra verification failures, e.g. cuDNN "
           "cross checking failures"),
       tensorflow::Flag(
-          "xla_gpu_disable_autotune",
-          bool_setter_for(&DebugOptions::set_xla_gpu_disable_autotune),
-          flag_values->xla_gpu_disable_autotune(),
-          "Disable GEMM and Convolution auto-tuning."),
+          "xla_gpu_autotune_level",
+          int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level),
+          flag_values->xla_gpu_autotune_level(),
+          "Set GEMM and Convolution auto-tuning level. "
+          "0 = off; 1 = on; 2 = on+init; 3 = on+init+reinit; 4 = "
+          "on+init+reinit+check."),
       tensorflow::Flag(
           "xla_force_host_platform_device_count",
           int32_setter_for(
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
index cb5f0dc1112..de67b115ff7 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
@@ -69,6 +69,10 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
   GemmBackendConfig backend_config =
       gemm->backend_config<GemmBackendConfig>().ValueOrDie();
+  const int32 cublas_autotune_level =
+      gemm->GetModule()->config().debug_options().xla_gpu_autotune_level();
+  const bool reinit_cublas_data = cublas_autotune_level > 2;
+  const bool check_cublas = cublas_autotune_level > 3;
 
   VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString();
@@ -81,7 +85,7 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
   for (se::blas::AlgorithmType algorithm : algorithms) {
     // Make sure the output buffer always has the same value if we use
     // the bias parameter.
-    if (backend_config.beta() != 0) {
+    if (reinit_cublas_data && backend_config.beta() != 0) {
       int64 rng_state = 0;
       InitializeBuffer(stream, gemm->shape().element_type(), &rng_state,
                        output_buffer);
@@ -114,6 +118,10 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
     *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
         absl::Milliseconds(profile_result.elapsed_time_in_ms()));
 
+    if (!check_cublas) {
+      continue;
+    }
+
     TF_ASSIGN_OR_RETURN(
         se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
         allocator.CheckRedzones());
@@ -248,6 +256,8 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
                       allocator->GetStream(executor->device_ordinal()));
   const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
+  const bool init_cublas_data =
+      hlo_module_config.debug_options().xla_gpu_autotune_level() > 1;
   se::RedzoneAllocator input_output_allocator(
       stream, allocator, PtxOptsFromConfig(hlo_module_config),
       /*memory_limit=*/std::numeric_limits<int64>::max());
@@ -260,7 +270,9 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
     TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
                         input_output_allocator.AllocateBytes(
                             ShapeUtil::ByteSizeOf(op->shape())));
-    InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
+    if (init_cublas_data) {
+      InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
+    }
     return buffer;
   };
@@ -316,7 +328,7 @@ static StatusOr<bool> RunOnComputation(HloComputation* computation,
 StatusOr<bool> GemmAlgorithmPicker::Run(HloModule* module) {
   XLA_SCOPED_LOGGING_TIMER("GemmAlgorithmPicker");
 
-  if (module->config().debug_options().xla_gpu_disable_autotune()) {
+  if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
     VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early";
     return false;
   }
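Both pickers derive their stage toggles inline from the same integer level: above, gemm_algorithm_picker.cc uses level > 2 (reinit) and level > 3 (check); the convolution picker below uses level > 1 (init) and level > 3 (check). As a compact reference, here is an illustrative standalone sketch of that mapping; the AutotuneStages struct and StagesForLevel helper are not part of the patch or of XLA:

#include <cstdint>

// Illustrative only: how --xla_gpu_autotune_level fans out into the stage
// toggles that the two algorithm pickers compute inline in this patch.
struct AutotuneStages {
  bool autotune;  // level >= 1: profile candidate algorithms and pick one
  bool init;      // level >= 2: fill autotune buffers with random data
  bool reinit;    // level >= 3: re-fill the GEMM output before each candidate
  bool check;     // level >= 4: redzone and correctness checks per candidate
};

constexpr AutotuneStages StagesForLevel(int32_t level) {
  return AutotuneStages{level > 0, level > 1, level > 2, level > 3};
}

static_assert(StagesForLevel(4).check, "the new default level enables checks");
static_assert(!StagesForLevel(3).check, "level 3 skips result verification");

The GEMM re-initialization stage exists because a GEMM with beta != 0 accumulates into its output (the "bias parameter" in the comment above), so each candidate algorithm must start from identical output data for its timing and results to be comparable.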
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
index fea06eed025..4562996a65f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
@@ -343,14 +343,19 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
   const Shape& result_shape = instr->shape().tuple_shapes(0);
 
   int64 rng_state = 0;
 
-  const auto initialize_buffer = [&stream, &rng_state](
+  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
+  const int32 conv_autotune_level =
+      hlo_module_config.debug_options().xla_gpu_autotune_level();
+  const bool init_conv_data = conv_autotune_level > 1;
+  const bool check_conv = conv_autotune_level > 3;
+  const auto initialize_buffer = [init_conv_data, &stream, &rng_state](
                                      DeviceMemoryBase buffer,
                                      const Shape& buffer_shape) {
-    InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer);
+    if (init_conv_data) {
+      InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer);
+    }
   };
-  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
-
   // Allocate space for the input, filter, and output of the convolution.
   se::RedzoneAllocator input_output_allocator(
       stream, allocator, PtxOptsFromConfig(hlo_module_config));
@@ -444,6 +449,10 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
     *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
         absl::Milliseconds(profile_result.elapsed_time_in_ms()));
 
+    if (!check_conv) {
+      continue;
+    }
+
     // Check for writes to redzones.
     TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear,
                         CheckRedzones(input_output_allocator, stream,
@@ -780,7 +789,7 @@ StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
 StatusOr<bool> GpuConvAlgorithmPicker::Run(HloModule* module) {
   XLA_SCOPED_LOGGING_TIMER("GpuConvAlgorithmPicker");
 
-  if (module->config().debug_options().xla_gpu_disable_autotune()) {
+  if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
    VLOG(2) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
               "returning early.";
    return false;
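The check stage that both passes now skip below level 4 is built on se::RedzoneAllocator, which surrounds each allocation with guard bytes and later verifies them (allocator.CheckRedzones() and CheckRedzones(...) above). The toy program below illustrates the general redzone idea only; none of its names or constants are XLA's:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Toy redzone check: pad a payload with a known pattern on both sides and
// verify the pattern afterwards. An out-of-bounds write by a buggy candidate
// kernel corrupts a guard region and disqualifies that algorithm.
constexpr uint8_t kRedzonePattern = 0xAB;
constexpr size_t kRedzoneBytes = 64;

bool RedzonesIntact(const std::vector<uint8_t>& backing, size_t payload) {
  for (size_t i = 0; i < kRedzoneBytes; ++i) {
    if (backing[i] != kRedzonePattern) return false;  // front guard
    if (backing[kRedzoneBytes + payload + i] != kRedzonePattern) {
      return false;  // back guard
    }
  }
  return true;
}

int main() {
  const size_t payload = 256;
  // Layout: [front redzone | payload | back redzone], all pattern-filled.
  std::vector<uint8_t> backing(2 * kRedzoneBytes + payload, kRedzonePattern);
  std::memset(backing.data() + kRedzoneBytes, 0, payload);  // in-bounds write
  std::printf("in-bounds write: %s\n",
              RedzonesIntact(backing, payload) ? "clear" : "corrupt");
  backing[kRedzoneBytes + payload] = 0;  // simulate a one-byte overrun
  std::printf("one-byte overrun: %s\n",
              RedzonesIntact(backing, payload) ? "clear" : "corrupt");
}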
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index a1676bf54cf..0593b6195fa 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -972,11 +972,11 @@ xla_test(
 )
 
 # Run dot tests with auto-tuning disabled. This just does a basic sanity check
-# that enabling xla_gpu_disable_autotune does not break simple graphs.
+# that setting xla_gpu_autotune_level to 0 does not break simple graphs.
 xla_test(
     name = "dot_operation_test_autotune_disabled",
     srcs = ["dot_operation_test.cc"],
-    args = ["--xla_gpu_disable_autotune"],
+    args = ["--xla_gpu_autotune_level=0"],
     backends = ["gpu"],
     shard_count = 20,
     tags = [
@@ -1151,13 +1151,13 @@ xla_test(
 )
 
 # Run convolution tests with auto-tuning disabled. This just does a basic
-# sanity check that enabling xla_gpu_disable_autotune does not break simple
+# sanity check that setting xla_gpu_autotune_level to 0 does not break simple
 # graphs.
 xla_test(
     name = "convolution_test_autotune_disabled",
     timeout = "long",
     srcs = ["convolution_test.cc"],
-    args = ["--xla_gpu_disable_autotune"],
+    args = ["--xla_gpu_autotune_level=0"],
     backends = ["gpu"],
     shard_count = 40,
     tags = [
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 29357615bd2..259c3290ed6 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -158,7 +158,7 @@ message DebugOptions {
   bool xla_gpu_crash_on_verification_failures = 101;
 
   // Disable GEMM and Convolution auto-tuning.
-  bool xla_gpu_disable_autotune = 123;
+  int32 xla_gpu_autotune_level = 123;
 
   // Force the host platform to pretend that there are these many host
   // "devices". All these devices are backed by the same threadpool. Defaults
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index bf7dbd5936e..f0ec8e5d539 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1646,9 +1646,9 @@ def disable_cudnn_autotune(func):
     original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE")
     os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false"
     original_xla_flags = os.environ.get("XLA_FLAGS")
-    new_xla_flags = "--xla_gpu_disable_autotune"
+    new_xla_flags = "--xla_gpu_autotune_level=0"
     if original_xla_flags:
-      new_xla_flags += " " + original_xla_flags
+      new_xla_flags = original_xla_flags + " " + new_xla_flags
     os.environ["XLA_FLAGS"] = new_xla_flags
     result = f(self, *args, **kwargs)
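A subtlety in the test_util.py hunk: besides renaming the flag, it reverses the concatenation order so the override is appended after any pre-existing XLA_FLAGS rather than prepended. Assuming flag occurrences are applied left to right (last one wins), this ensures --xla_gpu_autotune_level=0 takes effect even if the surrounding environment already sets an autotune level. A hypothetical C++ analogue of the decorator's environment handling, for illustration only:

#include <cstdlib>
#include <string>

// Hypothetical helper mirroring the Python change: append the override last
// so it overrides any level the caller already put into XLA_FLAGS.
void DisableXlaAutotuneForTest() {
  const char* existing = std::getenv("XLA_FLAGS");
  std::string flags = existing ? std::string(existing) + " " : std::string();
  flags += "--xla_gpu_autotune_level=0";
  setenv("XLA_FLAGS", flags.c_str(), /*overwrite=*/1);  // POSIX setenv
}

Note also that the xla.proto change reuses field number 123 while changing the type from bool to int32. Both are varint-encoded, so the change is wire-compatible, but the meaning of a stale serialized value flips: the old "disable = true" encodes as 1, which now reads as level 1, i.e. autotuning enabled without buffer initialization or checks.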