Merge pull request from nouiz:vlog_ptxas

PiperOrigin-RevId: 315938190
Change-Id: Ib5ac12f94e3eaac65e97bdfd650af95bedf3f326
This commit is contained in:
TensorFlower Gardener 2020-06-11 11:26:47 -07:00
commit f60069b7bb
4 changed files with 23 additions and 2 deletions
tensorflow

View File

@ -37,6 +37,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_autotune_level(4);
opts.set_xla_cpu_multi_thread_eigen(true);
opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
opts.set_xla_gpu_asm_extra_flags("");
opts.set_xla_eliminate_hlo_implicit_broadcast(true);
opts.set_xla_dump_hlo_as_html(false);
opts.set_xla_dump_include_timestamp(true);
@ -430,6 +431,11 @@ static void AllocateFlags() {
bool_setter_for(&DebugOptions::set_xla_gpu_disable_gpuasm_optimizations),
flag_values->xla_gpu_disable_gpuasm_optimizations(),
"In XLA:GPU run ptxas in -O0 (default is -O3)."));
flag_objects->push_back(tensorflow::Flag(
"xla_gpu_asm_extra_flags",
string_setter_for(&DebugOptions::set_xla_gpu_asm_extra_flags), "",
"Pass extra parameters to the GPU assembler tool (i.e., ptxas for CUDA). "
"If multiple parameters, separate them by comma."));
flag_objects->push_back(tensorflow::Flag(
"xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"",
"Sets compiler fuel, useful for bisecting bugs in passes. Format "

View File

@ -222,9 +222,13 @@ Status ExecuteKernelOnStream(const se::KernelBase& kernel,
}
// Builds the GPU assembler options for a module from its debug options.
//
// Three debug options of `hlo_module_config` are forwarded to se::GpuAsmOpts:
//   - xla_gpu_disable_gpuasm_optimizations
//   - xla_gpu_cuda_data_dir
//   - xla_gpu_asm_extra_flags, a comma-separated list that is split into
//     individual flags before being passed on.
se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) {
  const auto& debug_options = hlo_module_config.debug_options();
  const auto& extra_string = debug_options.xla_gpu_asm_extra_flags();
  // SkipEmpty drops empty pieces, so an unset flag ("") yields no extra
  // arguments and stray commas ("a,,b") do not produce empty arguments.
  const std::vector<std::string> extra_flags =
      absl::StrSplit(extra_string, ",", absl::SkipEmpty());
  return se::GpuAsmOpts(debug_options.xla_gpu_disable_gpuasm_optimizations(),
                        debug_options.xla_gpu_cuda_data_dir(), extra_flags);
}
// Unimplemented for integers yet.

View File

@ -287,7 +287,10 @@ message DebugOptions {
// memory, or have bugs.
bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found = 138;
// Next id: 141
// Extra parameters to pass to the GPU assembler.
string xla_gpu_asm_extra_flags = 141;
// Next id: 142
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.

View File

@ -214,6 +214,10 @@ port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
}
ptxas_args.insert(ptxas_args.end(), options.extra_flags.begin(),
options.extra_flags.end());
if (VLOG_IS_ON(3)) {
VLOG(3) << absl::StrJoin(ptxas_args, " ");
}
ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args);
ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR,
tensorflow::ACTION_PIPE);
@ -228,6 +232,10 @@ port::StatusOr<std::vector<uint8>> CompileGpuAsm(int cc_major, int cc_minor,
absl::StrFormat("ptxas exited with non-zero error code %d, output: %s",
exit_status, stderr_output));
}
// Print the verbose output of ptxas.
if (!stderr_output.empty()) {
VLOG(2) << stderr_output;
}
// Read in the result of compilation and return it as a byte vector.
std::string cubin;