From dcf8ec6a9578af9a3ac1eae535248b77db120019 Mon Sep 17 00:00:00 2001 From: Bas Aarts Date: Mon, 23 Dec 2019 16:39:15 -0800 Subject: [PATCH 1/3] Replace --xla_gpu_disable_autotune option with --xla_gpu_autotune_level XLA performs autotuning for convolutions and GEMMs. for each algorithm it runs some checks to check for out of bounds access or functional errors. The latter in particular can take a lot of time, increasing compile time by orders of magnitude. This hurts end-to-end execution time. --xla_gpu_autotune_level enables clients to set a level. 0: don't autotune 1: autotune with uninitialized data 2: autotune with initialized data 3: autotune with initialized data, and reinitialize for inplace case 4: autotune with initialized data, reinitialize, and check The deafult is 4, not changing the current behaviour. Change some tests accordingly. --- tensorflow/compiler/xla/debug_options_flags.cc | 10 ++++++---- .../xla/service/gpu/gemm_algorithm_picker.cc | 18 +++++++++++++++--- .../service/gpu/gpu_conv_algorithm_picker.cc | 18 +++++++++++++----- tensorflow/compiler/xla/tests/BUILD | 8 ++++---- tensorflow/compiler/xla/xla.proto | 2 +- 5 files changed, 39 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index fe2b2ca6dc4..c034948016f 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -34,6 +34,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_llvm_enable_invariant_load_metadata(true); opts.set_xla_llvm_disable_expensive_passes(false); opts.set_xla_backend_optimization_level(3); + opts.set_xla_gpu_autotune_level(4); opts.set_xla_cpu_multi_thread_eigen(true); opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); opts.set_xla_eliminate_hlo_implicit_broadcast(true); @@ -399,10 +400,11 @@ static void AllocateFlags() { "Crashes the program on extra verification failures, e.g. cuDNN " "cross checking failures"), tensorflow::Flag( - "xla_gpu_disable_autotune", - bool_setter_for(&DebugOptions::set_xla_gpu_disable_autotune), - flag_values->xla_gpu_disable_autotune(), - "Disable GEMM and Convolution auto-tuning."), + "xla_gpu_autotune_level", + int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level), + flag_values->xla_gpu_autotune_level(), + "Set GEMM and Convolution auto-tuning level." + "0 = off; 1 = on; 2 = on+init; 3 = on+init+reinit; 4 = on+init+reinit+check."), tensorflow::Flag( "xla_force_host_platform_device_count", int32_setter_for( diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index cb5f0dc1112..0a1f377278e 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -69,6 +69,10 @@ static StatusOr> DoUncachedGemmAutotune( GemmBackendConfig backend_config = gemm->backend_config().ValueOrDie(); + int32 cublas_level = + gemm->GetModule()->config().debug_options().xla_gpu_autotune_level(); + bool reinit_cublas_data = cublas_level > 2; + bool check_cublas = cublas_level > 3; VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString(); @@ -81,7 +85,7 @@ static StatusOr> DoUncachedGemmAutotune( for (se::blas::AlgorithmType algorithm : algorithms) { // Make sure the output buffer always has the same value if we use // the bias parameter. - if (backend_config.beta() != 0) { + if (reinit_cublas_data && backend_config.beta() != 0) { int64 rng_state = 0; InitializeBuffer(stream, gemm->shape().element_type(), &rng_state, output_buffer); @@ -114,6 +118,10 @@ static StatusOr> DoUncachedGemmAutotune( *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto( absl::Milliseconds(profile_result.elapsed_time_in_ms())); + if (!check_cublas) { + continue; + } + TF_ASSIGN_OR_RETURN( se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, allocator.CheckRedzones()); @@ -248,6 +256,8 @@ static StatusOr RunOnInstruction(HloInstruction* instr, allocator->GetStream(executor->device_ordinal())); const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); + bool init_cublas_data = + hlo_module_config.debug_options().xla_gpu_autotune_level() > 1; se::RedzoneAllocator input_output_allocator( stream, allocator, PtxOptsFromConfig(hlo_module_config), /*memory_limit=*/std::numeric_limits::max()); @@ -260,7 +270,9 @@ static StatusOr RunOnInstruction(HloInstruction* instr, TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(op->shape()))); - InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer); + if (init_cublas_data) { + InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer); + } return buffer; }; @@ -316,7 +328,7 @@ static StatusOr RunOnComputation(HloComputation* computation, StatusOr GemmAlgorithmPicker::Run(HloModule* module) { XLA_SCOPED_LOGGING_TIMER("GemmAlgorithmPicker"); - if (module->config().debug_options().xla_gpu_disable_autotune()) { + if (module->config().debug_options().xla_gpu_autotune_level() == 0) { VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early"; return false; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc index 71a86207987..0ac7c236da9 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc @@ -320,14 +320,18 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( const Shape& result_shape = instr->shape().tuple_shapes(0); int64 rng_state = 0; - const auto initialize_buffer = [&stream, &rng_state]( + const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); + int32 conv_level = hlo_module_config.debug_options().xla_gpu_autotune_level(); + bool init_conv_data = conv_level > 1; + bool check_conv = conv_level > 3; + const auto initialize_buffer = [init_conv_data, &stream, &rng_state]( DeviceMemoryBase buffer, const Shape& buffer_shape) { - InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer); + if (init_conv_data) { + InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer); + } }; - const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); - // Allocate space for the input, filter, and output of the convolution. se::RedzoneAllocator input_output_allocator( stream, allocator, PtxOptsFromConfig(hlo_module_config)); @@ -421,6 +425,10 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto( absl::Milliseconds(profile_result.elapsed_time_in_ms())); + if (!check_conv) { + continue; + } + // Check for writes to redzones. TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear, CheckRedzones(input_output_allocator, stream, @@ -718,7 +726,7 @@ StatusOr GpuConvAlgorithmPicker::RunOnComputation( StatusOr GpuConvAlgorithmPicker::Run(HloModule* module) { XLA_SCOPED_LOGGING_TIMER("GpuConvAlgorithmPicker"); - if (module->config().debug_options().xla_gpu_disable_autotune()) { + if (module->config().debug_options().xla_gpu_autotune_level() == 0) { VLOG(2) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker " "returning early."; return false; diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index b2cc8050c42..77f4d57577f 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -949,11 +949,11 @@ xla_test( ) # Run dot tests with auto-tuning disabled. This just does a basic sanity check -# that enabling xla_gpu_disable_autotune does not break simple graphs. +# that setting xla_gpu_autotune_level to 0 does not break simple graphs. xla_test( name = "dot_operation_test_autotune_disabled", srcs = ["dot_operation_test.cc"], - args = ["--xla_gpu_disable_autotune"], + args = ["--xla_gpu_autotune_level=0"], backends = ["gpu"], shard_count = 20, tags = [ @@ -1121,13 +1121,13 @@ xla_test( ) # Run convolution tests with auto-tuning disabled. This just does a basic -# sanity check that enabling xla_gpu_disable_autotune does not break simple +# sanity check that setting xla_gpu_autotune_level to 0 does not break simple # graphs. xla_test( name = "convolution_test_autotune_disabled", timeout = "long", srcs = ["convolution_test.cc"], - args = ["--xla_gpu_disable_autotune"], + args = ["--xla_gpu_autotune_level=0"], backends = ["gpu"], shard_count = 40, tags = ["optonly"], diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 29357615bd2..259c3290ed6 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -158,7 +158,7 @@ message DebugOptions { bool xla_gpu_crash_on_verification_failures = 101; // Disable GEMM and Convolution auto-tuning. - bool xla_gpu_disable_autotune = 123; + int32 xla_gpu_autotune_level = 123; // Force the host platform to pretend that there are these many host // "devices". All these devices are backed by the same threadpool. Defaults From 3bc238c04ac8d61ceb8437dd01ffa3278906e35f Mon Sep 17 00:00:00 2001 From: Bas Aarts Date: Thu, 9 Jan 2020 10:09:08 -0800 Subject: [PATCH 2/3] address comments on commit 9be66fce9aa03bc4344422f59217b3bdf871311d --- .../compiler/xla/service/gpu/gemm_algorithm_picker.cc | 8 ++++---- .../compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index 0a1f377278e..b1f1797fbdf 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -69,10 +69,10 @@ static StatusOr> DoUncachedGemmAutotune( GemmBackendConfig backend_config = gemm->backend_config().ValueOrDie(); - int32 cublas_level = + const int32 cublas_autotune_level = gemm->GetModule()->config().debug_options().xla_gpu_autotune_level(); - bool reinit_cublas_data = cublas_level > 2; - bool check_cublas = cublas_level > 3; + const bool reinit_cublas_data = cublas_autotune_level > 2; + const bool check_cublas = cublas_autotune_level > 3; VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString(); @@ -256,7 +256,7 @@ static StatusOr RunOnInstruction(HloInstruction* instr, allocator->GetStream(executor->device_ordinal())); const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); - bool init_cublas_data = + const bool init_cublas_data = hlo_module_config.debug_options().xla_gpu_autotune_level() > 1; se::RedzoneAllocator input_output_allocator( stream, allocator, PtxOptsFromConfig(hlo_module_config), diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc index 0ac7c236da9..0693ab5b6a3 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc @@ -321,9 +321,10 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( int64 rng_state = 0; const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); - int32 conv_level = hlo_module_config.debug_options().xla_gpu_autotune_level(); - bool init_conv_data = conv_level > 1; - bool check_conv = conv_level > 3; + const int32 conv_autotune_level = + hlo_module_config.debug_options().xla_gpu_autotune_level(); + const bool init_conv_data = conv_autotune_level > 1; + const bool check_conv = conv_autotune_level > 3; const auto initialize_buffer = [init_conv_data, &stream, &rng_state]( DeviceMemoryBase buffer, const Shape& buffer_shape) { From 9cbf6a57d4281ef6a766423b102a2825eb5dedb4 Mon Sep 17 00:00:00 2001 From: Bas Aarts Date: Thu, 9 Jan 2020 12:37:53 -0800 Subject: [PATCH 3/3] Rebase against upstream --- tensorflow/python/framework/test_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 8c560e4aa8c..20206536bdc 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1642,9 +1642,9 @@ def disable_cudnn_autotune(func): original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE") os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false" original_xla_flags = os.environ.get("XLA_FLAGS") - new_xla_flags = "--xla_gpu_disable_autotune" + new_xla_flags = "--xla_gpu_autotune_level=0" if original_xla_flags: - new_xla_flags += " " + original_xla_flags + new_xla_flags = original_xla_flags + " " + new_xla_flags os.environ["XLA_FLAGS"] = new_xla_flags result = f(self, *args, **kwargs)