From 256cfa1823b40632fc7df1099d2b5213b465d8de Mon Sep 17 00:00:00 2001 From: Frederic Bastien Date: Mon, 8 Jun 2020 13:14:33 -0700 Subject: [PATCH] Add XLA_FLAGS=--xla_gpu_gpuasm_extra_flags=... --- tensorflow/compiler/xla/debug_options_flags.cc | 5 +++++ tensorflow/compiler/xla/service/gpu/stream_executor_util.cc | 6 +++++- tensorflow/compiler/xla/xla.proto | 5 ++++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 958629c5fa6..368d4d9ac69 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -37,6 +37,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_autotune_level(4); opts.set_xla_cpu_multi_thread_eigen(true); opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); + opts.set_xla_gpu_gpuasm_extra_flags(""); opts.set_xla_eliminate_hlo_implicit_broadcast(true); opts.set_xla_dump_hlo_as_html(false); opts.set_xla_dump_include_timestamp(true); @@ -430,6 +431,10 @@ static void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_gpu_disable_gpuasm_optimizations), flag_values->xla_gpu_disable_gpuasm_optimizations(), "In XLA:GPU run ptxas in -O0 (default is -O3).")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_gpuasm_extra_flags", string_setter_for(&DebugOptions::set_xla_gpu_gpuasm_extra_flags), + "", //flag_values->xla_gpu_gpuasm_extra_flags(), + "Pass extra parameters to the GPU assembler tool (i.e., ptxas for CUDA).")); flag_objects->push_back(tensorflow::Flag( "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"", "Sets compiler fuel, useful for bisecting bugs in passes. 
Format " diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index 5e351c493ed..bee7622e4da 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -222,9 +222,13 @@ Status ExecuteKernelOnStream(const se::KernelBase& kernel, } se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) { + string extra_string = hlo_module_config.debug_options().xla_gpu_gpuasm_extra_flags(); + std::vector<std::string> extra_flags; + extra_flags = absl::StrSplit(extra_string, " ", absl::SkipEmpty()); return se::GpuAsmOpts( hlo_module_config.debug_options().xla_gpu_disable_gpuasm_optimizations(), - hlo_module_config.debug_options().xla_gpu_cuda_data_dir()); + hlo_module_config.debug_options().xla_gpu_cuda_data_dir(), + extra_flags); } // Unimplemented for integers yet. diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 6595bcbe292..62af92a262d 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -287,7 +287,10 @@ message DebugOptions { // memory, or have bugs. bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found = 138; - // Next id: 141 + // Extra parameters to pass to the GPU assembler. + string xla_gpu_gpuasm_extra_flags = 141; + + // Next id: 142 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.