diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
index 703159a25cb..6d5e864289c 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
@@ -129,14 +129,6 @@ AMDGPUCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
                                                       rocdl_dir_));
   }
 
-  if (debug_module) {
-    llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module, /*optimized=*/false);
-  }
-
-  if (user_post_optimization_hook_) {
-    user_post_optimization_hook_(*llvm_module);
-  }
-
   return std::pair<std::string, std::vector<uint8>>("", std::move(hsaco));
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 3a17630fea2..2c51bacb32d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -661,8 +661,8 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
 
   const auto compile_single_module =
       [this, stream_exec, &module_config, debug_module](
-          llvm::Module* llvm_module,
-          bool relocatable) -> StatusOr<BackendCompileResult> {
+          llvm::Module* llvm_module, bool relocatable,
+          absl::optional<int> shard_number) -> StatusOr<BackendCompileResult> {
     {
       XLA_SCOPED_LOGGING_TIMER(
           "GpuCompiler::RunBackend - Running LLVM verifier");
@@ -682,8 +682,55 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
                   : ".");
     }
     GpuVersion gpu_version = GetGpuVersion(stream_exec);
-    return CompileTargetBinary(module_config, llvm_module, gpu_version,
-                               stream_exec, relocatable, debug_module);
+    StatusOr<std::pair<std::string, std::vector<uint8>>> result =
+        CompileTargetBinary(module_config, llvm_module, gpu_version,
+                            stream_exec, relocatable, debug_module);
+
+    if (!result.ok()) {
+      return result;
+    }
+
+    if (DumpingEnabledForHloModule(*debug_module)) {
+      if (debug_module) {
+        if (shard_number.has_value()) {
+          llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
+                                   /*optimized=*/true,
+                                   std::to_string(*shard_number));
+        } else {
+          llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
+                                   /*optimized=*/true);
+        }
+      } else {
+        LOG(ERROR)
+            << "Dumping is not implemented since the file name cannot be "
+               "inferred. Please implement (potentially MLIR) module -> "
+               "filename heuristic.";
+      }
+    }
+
+    if (user_post_optimization_hook_) {
+      user_post_optimization_hook_(*llvm_module);
+    }
+
+    // Write PTX to IR dump directory, if IR dumping was requested.
+    if (DumpingEnabledForHloModule(*debug_module)) {
+      absl::string_view ptx = result->first;
+      if (debug_module) {
+        if (shard_number.has_value()) {
+          DumpToFileInDirOrStdout(*debug_module, "",
+                                  std::to_string(*shard_number) + ".ptx", ptx);
+        } else {
+          DumpToFileInDirOrStdout(*debug_module, "", "ptx", ptx);
+        }
+      } else {
+        LOG(ERROR)
+            << "Dumping is not implemented since the file name cannot be "
+               "inferred. Please implement (potentially MLIR) module -> "
+               "filename heuristic.";
+      }
+    }
+
+    return result;
   };
 
   tensorflow::thread::ThreadPool* thread_pool = options.thread_pool;
@@ -698,13 +745,15 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
   }
 
   if (!thread_pool) {
-    return compile_single_module(llvm_module.get(), /*relocatable=*/false);
+    return compile_single_module(llvm_module.get(), /*relocatable=*/false,
+                                 /*shard_number=*/absl::nullopt);
   }
 
   // Test whether LinkModules is supported.
   if (this->LinkModules(stream_exec, {}).status().code() ==
       tensorflow::error::Code::UNIMPLEMENTED) {
-    return compile_single_module(llvm_module.get(), /*relocatable=*/false);
+    return compile_single_module(llvm_module.get(), /*relocatable=*/false,
+                                 /*shard_number=*/absl::nullopt);
   }
 
   std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
@@ -756,8 +805,8 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
         new_llvm_module = llvm::parseAssemblyString(ir, err, context);
       }
 
-      compile_results[i] =
-          compile_single_module(new_llvm_module.get(), /*relocatable=*/true);
+      compile_results[i] = compile_single_module(
+          new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
       counter.DecrementCount();
     });
   }
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index a57c7ca812c..4fdd0f3d80d 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -265,13 +265,14 @@ void FeedLLVMWithFlags(const std::vector<string>& cl_opts) {
 }
 
 // Returns whether the module could use any device bitcode library functions.
-// This function may have false positives -- the module might not use libdevice
-// on NVPTX or ROCm-Device-Libs on AMDGPU even if this function returns true.
 bool CouldNeedDeviceBitcode(const llvm::Module& module) {
   for (const llvm::Function& function : module.functions()) {
-    // This is a conservative approximation -- not all such functions are in
-    // libdevice or ROCm-Device-Libs.
-    if (!function.isIntrinsic() && function.isDeclaration()) {
+    // The list of prefixes should be in sync with library functions used in
+    // target_util.cc.
+    if (!function.isIntrinsic() && function.isDeclaration() &&
+        (function.getName().startswith("__nv_") ||
+         function.getName().startswith("__ocml_") ||
+         function.getName().startswith("__ockl_"))) {
       return true;
     }
   }
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 77fbdcd3672..eb605777e7c 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -330,31 +330,6 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
                                                  module_config, libdevice_dir));
   }
 
-  if (DumpingEnabledForHloModule(*debug_module)) {
-    if (debug_module) {
-      llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module, /*optimized=*/true);
-    } else {
-      LOG(ERROR) << "Dumping is not implemented since the file name cannot be "
-                    "inferred. Please implement (potentially MLIR) module -> "
-                    "filename heuristic.";
-    }
-  }
-
-  if (user_post_optimization_hook_) {
-    user_post_optimization_hook_(*llvm_module);
-  }
-
-  // Write PTX to IR dump directory, if IR dumping was requested.
-  if (DumpingEnabledForHloModule(*debug_module)) {
-    if (debug_module) {
-      DumpToFileInDirOrStdout(*debug_module, "", "ptx", ptx);
-    } else {
-      LOG(ERROR) << "Dumping is not implemented since the file name cannot be "
-                    "inferred. Please implement (potentially MLIR) module -> "
-                    "filename heuristic.";
-    }
-  }
-
   std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult(
       stream_exec, ptx, compute_capability.first, compute_capability.second,
       module_config, relocatable);
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index 584b47b9f6c..a00156a0e4a 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -576,7 +576,8 @@ static Status CreateAndWriteStringToFile(const string& directory_name,
 }
 
 void DumpIrIfEnabled(const HloModule& hlo_module,
-                     const llvm::Module& llvm_module, bool optimized) {
+                     const llvm::Module& llvm_module, bool optimized,
+                     absl::string_view filename_suffix) {
   const auto& debug_opts = hlo_module.config().debug_options();
   if (!DumpingEnabledForHloModule(hlo_module)) {
     return;
@@ -585,8 +586,11 @@ void DumpIrIfEnabled(const HloModule& hlo_module,
   // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously
   // dumped from the same process in such cases.
   string suffix = absl::StrCat("ir-", optimized ? "with" : "no", "-opt");
-  DumpToFileInDirOrStdout(hlo_module, "", absl::StrCat(suffix, ".ll"),
-                          DumpModuleToString(llvm_module));
+  DumpToFileInDirOrStdout(
+      hlo_module, "",
+      absl::StrCat(suffix, filename_suffix.empty() ? "" : ".", filename_suffix,
+                   ".ll"),
+      DumpModuleToString(llvm_module));
 
   // For some models the embedded constants can be huge, so also dump the module
   // with the constants stripped to get IR that is easier to manipulate. Skip
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index c0a55e4da33..3a3b4b77d70 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -272,7 +272,8 @@ std::map<int, llvm::MDNode*> MergeMetadata(
 // If `optimized` is true then a suffix of "-with-opt.ll" is used, else a suffix
 // of "-no-opt.ll" is used.
 void DumpIrIfEnabled(const HloModule& hlo_module,
-                     const llvm::Module& llvm_module, bool optimized);
+                     const llvm::Module& llvm_module, bool optimized,
+                     absl::string_view filename_suffix = "");
 
 llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
                                   llvm::GlobalValue::LinkageTypes linkage,
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 8dfb094e168..562ba4e760b 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -59,10 +59,6 @@ class GpuDummyCompiler : public GpuCompiler {
      const HloModuleConfig& module_config, llvm::Module* llvm_module,
      GpuVersion gpu_version, se::StreamExecutor* stream_exec, bool relocatable,
      const HloModule* debug_module) {
-    if (user_post_optimization_hook_) {
-      user_post_optimization_hook_(*llvm_module);
-    }
-
    std::vector<uint8> compiled_results;
    return std::pair<std::string, std::vector<uint8>>(
        "", std::move(compiled_results));
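
For reference, the standalone sketch below (not part of the patch) illustrates how the dump names come out after this change: DumpIrIfEnabled now splices the optional filename_suffix between the "ir-{with,no}-opt" prefix and ".ll", and compile_single_module passes the shard index so that per-shard PTX dumps get "<shard>.ptx" instead of "ptx". The helper names IrDumpSuffix and PtxDumpSuffix are hypothetical, std::optional stands in for absl::optional, and only the suffix handed to DumpToFileInDirOrStdout is modeled (that helper still prepends the module's own file-name prefix).

#include <iostream>
#include <optional>
#include <string>

// Hypothetical helper mirroring the name built inside llvm_ir::DumpIrIfEnabled:
// "ir-with-opt.ll" / "ir-no-opt.ll", or "ir-with-opt.<shard>.ll" when a
// per-shard filename_suffix is supplied.
std::string IrDumpSuffix(bool optimized, std::optional<int> shard_number) {
  std::string suffix = std::string("ir-") + (optimized ? "with" : "no") + "-opt";
  std::string filename_suffix =
      shard_number.has_value() ? std::to_string(*shard_number) : "";
  return suffix + (filename_suffix.empty() ? "" : ".") + filename_suffix + ".ll";
}

// Hypothetical helper mirroring the suffix that compile_single_module passes to
// DumpToFileInDirOrStdout for PTX: "ptx" for an unsharded module, "<shard>.ptx"
// for each shard compiled on the thread pool.
std::string PtxDumpSuffix(std::optional<int> shard_number) {
  return shard_number.has_value() ? std::to_string(*shard_number) + ".ptx"
                                  : std::string("ptx");
}

int main() {
  std::cout << IrDumpSuffix(/*optimized=*/true, std::nullopt) << "\n";  // ir-with-opt.ll
  std::cout << IrDumpSuffix(/*optimized=*/true, 3) << "\n";             // ir-with-opt.3.ll
  std::cout << PtxDumpSuffix(std::nullopt) << "\n";                     // ptx
  std::cout << PtxDumpSuffix(3) << "\n";                                // 3.ptx
}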