[XLA/GPU] Fix issues brought by multi-threaded compilation:

* Improve the heuristics on libdevice usage detection.
* Dump per-thread LLVM modules correctly.

PiperOrigin-RevId: 350412026
Change-Id: I84e30b414b9d9fbbc67bb0710379b06966cf5f9e
This commit is contained in:
Tim Shen 2021-01-06 13:03:06 -08:00 committed by TensorFlower Gardener
parent 34572c38f6
commit 6b351bf94b
7 changed files with 72 additions and 54 deletions

View File

@ -129,14 +129,6 @@ AMDGPUCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
rocdl_dir_));
}
if (debug_module) {
llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module, /*optimized=*/false);
}
if (user_post_optimization_hook_) {
user_post_optimization_hook_(*llvm_module);
}
return std::pair<std::string, std::vector<uint8>>("", std::move(hsaco));
}

View File

@ -661,8 +661,8 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
const auto compile_single_module =
[this, stream_exec, &module_config, debug_module](
llvm::Module* llvm_module,
bool relocatable) -> StatusOr<BackendCompileResult> {
llvm::Module* llvm_module, bool relocatable,
absl::optional<int> shard_number) -> StatusOr<BackendCompileResult> {
{
XLA_SCOPED_LOGGING_TIMER(
"GpuCompiler::RunBackend - Running LLVM verifier");
@ -682,8 +682,55 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
: ".");
}
GpuVersion gpu_version = GetGpuVersion(stream_exec);
return CompileTargetBinary(module_config, llvm_module, gpu_version,
stream_exec, relocatable, debug_module);
StatusOr<std::pair<std::string, std::vector<uint8>>> result =
CompileTargetBinary(module_config, llvm_module, gpu_version,
stream_exec, relocatable, debug_module);
if (!result.ok()) {
return result;
}
if (DumpingEnabledForHloModule(*debug_module)) {
if (debug_module) {
if (shard_number.has_value()) {
llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
/*optimized=*/true,
std::to_string(*shard_number));
} else {
llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module,
/*optimized=*/true);
}
} else {
LOG(ERROR)
<< "Dumping is not implemented since the file name cannot be "
"inferred. Please implement (potentially MLIR) module -> "
"filename heuristic.";
}
}
if (user_post_optimization_hook_) {
user_post_optimization_hook_(*llvm_module);
}
// Write PTX to IR dump directory, if IR dumping was requested.
if (DumpingEnabledForHloModule(*debug_module)) {
absl::string_view ptx = result->first;
if (debug_module) {
if (shard_number.has_value()) {
DumpToFileInDirOrStdout(*debug_module, "",
std::to_string(*shard_number) + ".ptx", ptx);
} else {
DumpToFileInDirOrStdout(*debug_module, "", "ptx", ptx);
}
} else {
LOG(ERROR)
<< "Dumping is not implemented since the file name cannot be "
"inferred. Please implement (potentially MLIR) module -> "
"filename heuristic.";
}
}
return result;
};
tensorflow::thread::ThreadPool* thread_pool = options.thread_pool;
@ -698,13 +745,15 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
}
if (!thread_pool) {
return compile_single_module(llvm_module.get(), /*relocatable=*/false);
return compile_single_module(llvm_module.get(), /*relocatable=*/false,
/*shard_number=*/absl::nullopt);
}
// Test whether LinkModules is supported.
if (this->LinkModules(stream_exec, {}).status().code() ==
tensorflow::error::Code::UNIMPLEMENTED) {
return compile_single_module(llvm_module.get(), /*relocatable=*/false);
return compile_single_module(llvm_module.get(), /*relocatable=*/false,
/*shard_number=*/absl::nullopt);
}
std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
@ -756,8 +805,8 @@ GpuCompiler::CompileToTargetBinary(const HloModuleConfig& module_config,
new_llvm_module = llvm::parseAssemblyString(ir, err, context);
}
compile_results[i] =
compile_single_module(new_llvm_module.get(), /*relocatable=*/true);
compile_results[i] = compile_single_module(
new_llvm_module.get(), /*relocatable=*/true, /*shard_number=*/i);
counter.DecrementCount();
});
}

View File

@ -265,13 +265,14 @@ void FeedLLVMWithFlags(const std::vector<string>& cl_opts) {
}
// Returns whether the module could use any device bitcode library functions.
// This function may have false positives -- the module might not use libdevice
// on NVPTX or ROCm-Device-Libs on AMDGPU even if this function returns true.
bool CouldNeedDeviceBitcode(const llvm::Module& module) {
for (const llvm::Function& function : module.functions()) {
// This is a conservative approximation -- not all such functions are in
// libdevice or ROCm-Device-Libs.
if (!function.isIntrinsic() && function.isDeclaration()) {
// The list of prefixes should be in sync with library functions used in
// target_util.cc.
if (!function.isIntrinsic() && function.isDeclaration() &&
(function.getName().startswith("__nv_") ||
function.getName().startswith("__ocml_") ||
function.getName().startswith("__ockl_"))) {
return true;
}
}

View File

@ -330,31 +330,6 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
module_config, libdevice_dir));
}
if (DumpingEnabledForHloModule(*debug_module)) {
if (debug_module) {
llvm_ir::DumpIrIfEnabled(*debug_module, *llvm_module, /*optimized=*/true);
} else {
LOG(ERROR) << "Dumping is not implemented since the file name cannot be "
"inferred. Please implement (potentially MLIR) module -> "
"filename heuristic.";
}
}
if (user_post_optimization_hook_) {
user_post_optimization_hook_(*llvm_module);
}
// Write PTX to IR dump directory, if IR dumping was requested.
if (DumpingEnabledForHloModule(*debug_module)) {
if (debug_module) {
DumpToFileInDirOrStdout(*debug_module, "", "ptx", ptx);
} else {
LOG(ERROR) << "Dumping is not implemented since the file name cannot be "
"inferred. Please implement (potentially MLIR) module -> "
"filename heuristic.";
}
}
std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult(
stream_exec, ptx, compute_capability.first, compute_capability.second,
module_config, relocatable);

View File

@ -576,7 +576,8 @@ static Status CreateAndWriteStringToFile(const string& directory_name,
}
void DumpIrIfEnabled(const HloModule& hlo_module,
const llvm::Module& llvm_module, bool optimized) {
const llvm::Module& llvm_module, bool optimized,
absl::string_view filename_suffix) {
const auto& debug_opts = hlo_module.config().debug_options();
if (!DumpingEnabledForHloModule(hlo_module)) {
return;
@ -585,8 +586,11 @@ void DumpIrIfEnabled(const HloModule& hlo_module,
// XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously
// dumped from the same process in such cases.
string suffix = absl::StrCat("ir-", optimized ? "with" : "no", "-opt");
DumpToFileInDirOrStdout(hlo_module, "", absl::StrCat(suffix, ".ll"),
DumpModuleToString(llvm_module));
DumpToFileInDirOrStdout(
hlo_module, "",
absl::StrCat(suffix, filename_suffix.empty() ? "" : ".", filename_suffix,
".ll"),
DumpModuleToString(llvm_module));
// For some models the embedded constants can be huge, so also dump the module
// with the constants stripped to get IR that is easier to manipulate. Skip

View File

@ -272,7 +272,8 @@ std::map<int, llvm::MDNode*> MergeMetadata(
// If `optimized` is true then a suffix of "-with-opt.ll" is used, else a suffix
// of "-no-opt.ll" is used.
void DumpIrIfEnabled(const HloModule& hlo_module,
const llvm::Module& llvm_module, bool optimized);
const llvm::Module& llvm_module, bool optimized,
absl::string_view filename_suffix = "");
llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
llvm::GlobalValue::LinkageTypes linkage,

View File

@ -59,10 +59,6 @@ class GpuDummyCompiler : public GpuCompiler {
const HloModuleConfig& module_config, llvm::Module* llvm_module,
GpuVersion gpu_version, se::StreamExecutor* stream_exec, bool relocatable,
const HloModule* debug_module) {
if (user_post_optimization_hook_) {
user_post_optimization_hook_(*llvm_module);
}
std::vector<uint8> compiled_results;
return std::pair<std::string, std::vector<uint8>>(
"", std::move(compiled_results));