Merge pull request #47121 from nouiz:upstream-llvm_file
PiperOrigin-RevId: 359719047
Change-Id: I5839d5a53a15506a80dc2719ed6fb172b716aeac
commit 362f9fc585
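In short: the change adds a repeated debug flag, xla_gpu_llvm_ir_file, that lets a developer substitute hand-edited textual LLVM IR for the IR XLA generated for a matching module, mirroring the existing xla_gpu_ptx_file mechanism one stage earlier in the GPU pipeline. The hunks below appear to touch tensorflow/compiler/xla/debug_options_flags.cc, the GPU service BUILD file, nvptx_compiler.cc, and xla.proto. XLA debug flags are typically passed via the XLA_FLAGS environment variable, so usage would look something like XLA_FLAGS=--xla_gpu_llvm_ir_file=/tmp/module_0001.MY_NEW_FILE.ll (path illustrative).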
@@ -173,6 +173,12 @@ static void AllocateFlags() {
     return true;
   };

+  // Custom "sub-parser" lambda for xla_gpu_llvm_ir_file.
+  auto setter_for_xla_gpu_llvm_ir_file = [](const string& value) {
+    flag_values->add_xla_gpu_llvm_ir_file(value);
+    return true;
+  };
+
   // Custom "sub-parser" lambda for xla_backend_extra_options.
   auto setter_for_xla_backend_extra_options =
       [](string comma_separated_values) {
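For context, tensorflow::Flag accepts a custom setter callback for flag types the command-line parser has no built-in handling for, and the setter runs once per occurrence of the flag, which is what lets a repeated flag accumulate values. A minimal standalone sketch of the pattern, with a std::vector standing in for the repeated proto field (the vector and its name are assumptions for illustration):

    #include <string>
    #include <vector>

    // Stand-in for the repeated proto field; the real setter calls
    // flag_values->add_xla_gpu_llvm_ir_file(value) instead.
    std::vector<std::string>* llvm_ir_files = new std::vector<std::string>();

    // Invoked once per --xla_gpu_llvm_ir_file=<path> occurrence; returning
    // false would report a flag-parsing error.
    auto setter_for_xla_gpu_llvm_ir_file = [](const std::string& value) {
      llvm_ir_files->push_back(value);
      return true;
    };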
@@ -370,7 +376,15 @@ static void AllocateFlags() {
       "If non-empty, specifies a file containing ptx to use. The filename "
       "prefix must have the same pattern as PTX dumped by XLA. This allows to "
       "match one specific module. General workflow. Get the generated module "
-      "ptx from XLA. Modify it. Then pass it back via this option."));
+      "ptx from XLA, modify it, then pass it back via this option."));
+  flag_objects->push_back(tensorflow::Flag(
+      "xla_gpu_llvm_ir_file", setter_for_xla_gpu_llvm_ir_file, "",
+      "If non-empty, specifies a file containing textual LLVM IR to use. The "
+      "filename prefix must have the same pattern as LLVM dumped by XLA "
+      "(i.e. module_0001.ir-no-opt.ll -> module_0001.MY_NEW_FILE.ll). This "
+      "allows to match one specific module. General workflow. Get the not "
+      "optimized LLVM IR from XLA, modify it, then pass it back via this "
+      "option."));
   flag_objects->push_back(tensorflow::Flag(
       "xla_test_all_output_layouts",
       bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
@@ -1405,6 +1405,8 @@ cc_library(
         "//tensorflow/stream_executor:stream_executor_headers",
         "//tensorflow/stream_executor/cuda:cuda_diagnostics",
         "//tensorflow/stream_executor/gpu:asm_compiler",
+        "@llvm-project//llvm:IRReader",
+        "@llvm-project//llvm:Support",
     ]) + ["//tensorflow/stream_executor/gpu:gpu_driver_header"],
 )
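These two added deps appear to back the new includes in the next hunk: @llvm-project//llvm:IRReader provides llvm::parseIRFile (llvm/IRReader/IRReader.h), and @llvm-project//llvm:Support provides llvm::SMDiagnostic and the diagnostic-printing machinery (llvm/Support/SourceMgr.h).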
@@ -20,6 +20,8 @@ limitations under the License.
 #include <fstream>

 #include "absl/base/call_once.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/SourceMgr.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
 #include "tensorflow/compiler/xla/service/dump.h"
 #include "tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h"
@@ -202,7 +204,7 @@ absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
 // Try to load ptx from files defined in the FLAGS. If successful, return true.
 bool MaybeLoadPtxFromFile(const HloModuleConfig module_config,
                           const HloModule* module, std::string* ptx) {
-  // If the xla_gpu_ptx_file options is set, be explicit when a file is used
+  // If the xla_gpu_ptx_file option is set, be explicit if a file is used
   // and warn when a file is not used to ease catching typo in filename.
   std::string prefix = xla::FilenameFor(*module, "", *ptx);
   std::string matched_filename;
@@ -234,6 +236,50 @@ bool MaybeLoadPtxFromFile(const HloModuleConfig module_config,
   return false;
 }

+// Try to load textual LLVM IR from files defined in the FLAGS. If
+// successful, return the llvm::Module, otherwise return nullptr.
+std::unique_ptr<llvm::Module> MaybeLoadLLVMFromFile(const HloModule* module,
+                                                    llvm::Module* llvm_module) {
+  // If the xla_gpu_llvm_ir_file option is set, be explicit if a file is used
+  // and warn when a file is not used to ease catching typo in filename.
+  if (module == nullptr) {
+    return nullptr;
+  }
+
+  std::string prefix = xla::FilenameFor(*module, "", "");
+  auto xla_gpu_llvm_ir_file =
+      module->config().debug_options().xla_gpu_llvm_ir_file();
+  auto matched_filename = absl::c_find_if(
+      xla_gpu_llvm_ir_file, [prefix](const string& full_filename) {
+        // To ease comparing many LLVM versions, accept different suffixes than
+        // the original filename.
+        return absl::StartsWith(tensorflow::io::Basename(full_filename),
+                                prefix);
+      });
+  if (!xla_gpu_llvm_ir_file.empty() &&
+      matched_filename == std::end(xla_gpu_llvm_ir_file)) {
+    VLOG(0) << "RunBackend() - For module with prefix '" << prefix
+            << "', we did not find an LLVM file to load.";
+  }
+
+  if (matched_filename != std::end(xla_gpu_llvm_ir_file)) {
+    VLOG(0) << "RunBackend() - Will load LLVM from file: " << *matched_filename;
+    llvm::LLVMContext& context = llvm_module->getContext();
+    llvm::SMDiagnostic err;
+    std::unique_ptr<llvm::Module> loaded_module =
+        llvm::parseIRFile(*matched_filename, err, context);
+
+    if (!loaded_module) {
+      err.print("ERR", llvm::errs());
+      LOG(FATAL) << "Failed to load an LLVM file. It is probably invalid LLVM.";
+    }
+    // Overwrite the dumped unoptimized LLVM to show which one will be used.
+    llvm_ir::DumpIrIfEnabled(*module, *loaded_module, /*optimized=*/false);
+    return loaded_module;
+  }
+  return nullptr;
+}
+
 }  // namespace

 // Prints a warning if the ptx->sass JIT in the driver has known bugs.
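The heart of the new helper is llvm::parseIRFile, which parses textual IR (or bitcode) from a file into a fresh llvm::Module, reporting problems through an llvm::SMDiagnostic. A minimal sketch of that call outside of XLA (the file name is an illustrative assumption):

    #include <memory>

    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Module.h"
    #include "llvm/IRReader/IRReader.h"
    #include "llvm/Support/SourceMgr.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      llvm::LLVMContext context;
      llvm::SMDiagnostic err;
      // Hypothetical path; in the patch it comes from --xla_gpu_llvm_ir_file.
      std::unique_ptr<llvm::Module> m =
          llvm::parseIRFile("module_0001.MY_NEW_FILE.ll", err, context);
      if (!m) {
        // Mirrors err.print("ERR", llvm::errs()) in the patch above.
        err.print("example", llvm::errs());
        return 1;
      }
      llvm::outs() << "Loaded module: " << m->getName() << "\n";
      return 0;
    }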
@@ -320,13 +366,21 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
     libdevice_dir = cached_libdevice_dir_;
   }
   VLOG(2) << "Libdevice dir = " << libdevice_dir << "\n";
+  std::unique_ptr<llvm::Module> loaded_module =
+      MaybeLoadLLVMFromFile(debug_module, llvm_module);
+  llvm::Module* selected_module = nullptr;
+  if (loaded_module) {
+    selected_module = loaded_module.get();
+  } else {
+    selected_module = llvm_module;
+  }
+
   string ptx;
   if (!(debug_module &&
         MaybeLoadPtxFromFile(module_config, debug_module, &ptx))) {
     XLA_SCOPED_LOGGING_TIMER(
         "NVPTXCompiler::CompileTargetBinary - CompileToPtx");
-    TF_ASSIGN_OR_RETURN(ptx, nvptx::CompileToPtx(llvm_module, gpu_version,
+    TF_ASSIGN_OR_RETURN(ptx, nvptx::CompileToPtx(selected_module, gpu_version,
                                                  module_config, libdevice_dir));
   }
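Note the ownership split in this hunk: loaded_module (a std::unique_ptr) keeps the optionally-parsed file alive for the rest of the function, while selected_module is a non-owning raw pointer naming whichever module, loaded or XLA-generated, is actually handed to nvptx::CompileToPtx. The selection could equivalently be condensed to a sketch like (assuming the same types):

    // Prefer the module loaded from file; fall back to the generated one.
    llvm::Module* selected_module =
        loaded_module ? loaded_module.get() : llvm_module;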
@@ -314,7 +314,10 @@ message DebugOptions {
   // Compilation errors out if these ops are encountered.
   bool xla_gpu_deterministic_ops = 148;

-  // Next id: 150
+  // Paths to files with LLVM code.
+  repeated string xla_gpu_llvm_ir_file = 150;
+
+  // Next id: 151

   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
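Since xla_gpu_llvm_ir_file is an ordinary repeated string field on DebugOptions, it can also be populated programmatically, e.g. from a test, through the generated protobuf API rather than the flag. A sketch under that assumption (the helper name and path are hypothetical):

    #include "tensorflow/compiler/xla/xla.pb.h"

    // Append one IR-override path per call; repeated proto fields use add_*().
    void AddLlvmIrOverride(xla::DebugOptions* opts) {
      opts->add_xla_gpu_llvm_ir_file("/tmp/module_0001.MY_NEW_FILE.ll");
    }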